diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index d7e15377d42e6161ed0a5d3cf96dfc1d2046a4b4..f8b22ee70331095e6e69bc071799be3cf4f44893 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -28,6 +28,26 @@ jobs: working-directory: backend run: make test-integration + frontend: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + cache: 'pnpm' + cache-dependency-path: frontend/pnpm-lock.yaml + - name: Install frontend dependencies + working-directory: frontend + run: pnpm install --frozen-lockfile + - name: Frontend typecheck and critical vitest + run: make test-frontend + golangci-lint: runs-on: ubuntu-latest steps: @@ -46,4 +66,4 @@ jobs: with: version: v2.9 args: --timeout=30m - working-directory: backend \ No newline at end of file + working-directory: backend diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml new file mode 100644 index 0000000000000000000000000000000000000000..67c8d6e9800b748bf179f67382bbe65ff8b1ebbb --- /dev/null +++ b/.github/workflows/cla.yml @@ -0,0 +1,59 @@ +name: "CLA Assistant" + +on: + issue_comment: + types: [created] + pull_request_target: + types: [opened, reopened, closed, synchronize] + +permissions: + actions: write + contents: write + pull-requests: write + statuses: write + +jobs: + cla-check: + if: | + github.event_name == 'issue_comment' || + (github.event_name == 'pull_request_target' && github.event.action != 'closed') + runs-on: ubuntu-latest + steps: + - name: "CLA Assistant" + if: | + (github.event.comment.body == 'recheck' || + github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || + github.event_name == 'pull_request_target' + uses: contributor-assistant/github-action@v2.6.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + 
path-to-signatures: "cla.json" + path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md" + branch: "cla-signatures" + allowlist: "dependabot[bot],renovate[bot],bot*" + lock-pullrequest-aftermerge: false + custom-notsigned-prcomment: | + Thank you for your contribution! Before we can merge this PR, we need you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md). + + **To sign**, please reply with the following comment: + + > I have read the CLA Document and I hereby sign the CLA + + You only need to sign once — it will be valid for all your future contributions to this project. + custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA" + custom-allsigned-prcomment: "All contributors have signed the CLA. ✅" + + cla-lock: + if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: "Lock merged PR" + uses: contributor-assistant/github-action@v2.6.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + path-to-signatures: "cla.json" + path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md" + branch: "cla-signatures" + lock-pullrequest-aftermerge: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b729c575ea0833b09d156cbe166147cae6ab1b1c..26ed8524141c94efc0786943f766f042f4da0659 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -246,10 +246,10 @@ jobs: if [ -n "$DOCKERHUB_USERNAME" ]; then DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api" MESSAGE+="# Docker Hub"$'\n' - MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n' + MESSAGE+="docker pull ${DOCKER_IMAGE}:${VERSION}"$'\n' MESSAGE+="# GitHub Container Registry"$'\n' fi - MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n' + MESSAGE+="docker pull ${GHCR_IMAGE}:${VERSION}"$'\n' MESSAGE+="\`\`\`"$'\n'$'\n' MESSAGE+="🔗 
*相关链接:*"$'\n' MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n' diff --git a/.gitignore b/.gitignore index 1a92ea3e641316b8bc2f88def04cee30ec1377d5..bf7ee06411e607de6c5bbe153e69157a0ca91c9d 100644 --- a/.gitignore +++ b/.gitignore @@ -129,9 +129,9 @@ vite.config.js docs/* !docs/PAYMENT.md !docs/PAYMENT_CN.md +!docs/ADMIN_PAYMENT_INTEGRATION_API.md .serena/ .codex/ frontend/coverage/ aicodex output/ - diff --git a/CLA.md b/CLA.md new file mode 100644 index 0000000000000000000000000000000000000000..ed0d74b818109005cbfc1e07368c22afd385899b --- /dev/null +++ b/CLA.md @@ -0,0 +1,73 @@ +# Sub2API Individual Contributor License Agreement (v1.0) + +Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project. + +By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project. + +## 1. Definitions + +- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement. +- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution." 
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns. + +## 2. Grant of Copyright License + +Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses. + +## 3. Moral Rights + +To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights. + +## 4. Grant of Patent License + +Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted. + +## 5. 
Representations and Warranties + +You represent and warrant that: + +(a) You are legally entitled to grant the above licenses. + +(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project. + +(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions. + +## 6. No Warranty + +Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support. + +## 7. No Obligation + +You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution. + +## 8. Retention of Rights + +You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above. + +## 9. Term and Termination + +This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. 
The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable. + +## 10. Electronic Signature + +You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format. + +## 11. General Provisions + +**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter. + +**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect. + +**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision. + +**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission. + +**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt. 
+ +--- + +**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.** + +To sign, reply in the pull request with: + +> I have read the CLA Document and I hereby sign the CLA diff --git a/LICENSE b/LICENSE index 7a94ca9dafe869cee12ff810464dd2bb0ce488ee..153d416dc8d2d60076698ec3cbfce34d91436a03 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,165 @@ -MIT License - -Copyright (c) 2025 Wesley Liddick - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. 
+ + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. 
+ + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. 
+ + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. 
+ + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. 
\ No newline at end of file diff --git a/Makefile b/Makefile index fd6a5a9a5f684ba283bce7ead247d9788f5eb8b0..d00d0c4f5ef525f1ca4efd7a2c4c236cf080be5d 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,12 @@ -.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan +.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan + +FRONTEND_CRITICAL_VITEST := \ + src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \ + src/views/auth/__tests__/WechatCallbackView.spec.ts \ + src/views/user/__tests__/PaymentView.spec.ts \ + src/views/user/__tests__/PaymentResultView.spec.ts \ + src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \ + src/views/admin/__tests__/SettingsView.spec.ts # 一键编译前后端 build: build-backend build-frontend @@ -24,6 +32,10 @@ test-backend: test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck + @$(MAKE) test-frontend-critical + +test-frontend-critical: + @pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST) test-datamanagementd: @cd datamanagement && go test ./... diff --git a/README.md b/README.md index 74ab9af258f3a4172eecc4a05d0035603d68d527..3e609d656d43d7f6a082d6779ed71edb4d5d0205 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. 
By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF) + +bestproxy +Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control. + + ## Ecosystem @@ -618,7 +623,9 @@ sub2api/ ## License -MIT License +This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later). + +Copyright (c) 2026 Wesley Liddick --- diff --git a/README_CN.md b/README_CN.md index c701372c4393b6fe90b7eaba6da69b3308db91d6..add32a179b30253662705f3d48c939715b9a298b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -95,6 +95,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! + +bestproxy +感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。 + + ## 生态项目 @@ -679,7 +684,9 @@ sub2api/ ## 许可证 -MIT License +本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。 + +Copyright (c) 2026 Wesley Liddick --- diff --git a/README_JA.md b/README_JA.md index 0d4db616f08821d1191ae5ab64b14b16650c2af7..ccd595b93172e8a861fafc2e6aca47f89d35c6d2 100644 --- a/README_JA.md +++ b/README_JA.md @@ -95,6 +95,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを 本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます! 
+ +bestproxy +Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。 + + ## エコシステム @@ -617,7 +622,9 @@ sub2api/ ## ライセンス -MIT License +本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。 + +Copyright (c) 2026 Wesley Liddick --- diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png new file mode 100644 index 0000000000000000000000000000000000000000..87c586705020a09a07484d0b0cb9f3bff79a219d Binary files /dev/null and b/assets/partners/logos/bestproxy.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index c29f5f750e5ade7fc32c39cfbd9d6c2638597609..8b06068853dff58e2d64fded24056aa618305888 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.114 +0.1.117 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 64709b5b1cd653e0a1c083f3e7a2105ec5ce6ecf..9bfa27174db6b7cc1e77fd10968c4091ac61d4ba 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -97,6 +97,7 @@ func provideCleanup( scheduledTestRunner *service.ScheduledTestRunnerService, backupSvc *service.BackupService, paymentOrderExpiry *service.PaymentOrderExpiryService, + channelMonitorRunner *service.ChannelMonitorRunner, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -239,6 +240,12 @@ func provideCleanup( } return nil }}, + {"ChannelMonitorRunner", func() error { + if channelMonitorRunner != nil { + channelMonitorRunner.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1d39fa1e31a9c8b69f8b512cd2a504953f2cfa1d..93270e7e1a713e63e1a4044a626e79cd2ab30896 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -61,8 +61,9 @@ func initializeApplication(buildInfo handler.BuildInfo) 
(*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client, db) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) + userRPMCache := repository.NewUserRPMCache(redisClient) userGroupRateRepository := repository.NewUserGroupRateRepository(db) + billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) @@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService, emailService, emailCache) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -104,7 +105,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) 
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -124,9 +125,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) + openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) - rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) + rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, 
geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) @@ -136,7 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) @@ -174,7 +176,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() channelRepository := repository.NewChannelRepository(db) - channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) + channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) balanceNotifyService := 
service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) @@ -183,6 +185,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) + encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + if err != nil { + return nil, err + } + paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) + registry := payment.ProvideRegistry() + defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) 
updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -210,18 +221,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) - registry := payment.ProvideRegistry() - encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + sqlDB, err := repository.ProvideSQLDB(client) if err != nil { return nil, err } - defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) - paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) + channelMonitorRepository := repository.NewChannelMonitorRepository(client, sqlDB) + channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, sqlDB) + channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository) + channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService) + channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor) + channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) + channelMonitorUserHandler := 
handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) + channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) + availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -233,7 +247,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) idempotencyCoordinator := 
service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -242,13 +256,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) - opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) + opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService) opsScheduledReportService := 
service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) + paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, 
paymentOrderExpiryService, channelMonitorRunner) application := &Application{ Server: httpServer, Cleanup: v, @@ -302,6 +317,7 @@ func provideCleanup( scheduledTestRunner *service.ScheduledTestRunnerService, backupSvc *service.BackupService, paymentOrderExpiry *service.PaymentOrderExpiryService, + channelMonitorRunner *service.ChannelMonitorRunner, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -443,6 +459,12 @@ func provideCleanup( } return nil }}, + {"ChannelMonitorRunner", func() error { + if channelMonitorRunner != nil { + channelMonitorRunner.Stop() + } + return nil + }}, } infraSteps := []cleanupStep{ diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index a6e0551a382e35e462c73373899f5f713da6e962..5ccd67fb5cd2d4eb2a8e38edd78db35917ccc79b 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) @@ -76,6 +76,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { nil, // scheduledTestRunner nil, // backupSvc nil, // paymentOrderExpiry + nil, // channelMonitorRunner ) require.NotPanics(t, func() { diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go new file mode 100644 index 0000000000000000000000000000000000000000..5ccfcf19102d780646333f5f2da43d21ca8ce685 --- 
/dev/null +++ b/backend/ent/authidentity.go @@ -0,0 +1,266 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentity is the model entity for the AuthIdentity schema. +type AuthIdentity struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // VerifiedAt holds the value of the "verified_at" field. + VerifiedAt *time.Time `json:"verified_at,omitempty"` + // Issuer holds the value of the "issuer" field. + Issuer *string `json:"issuer,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityQuery when eager-loading is set. + Edges AuthIdentityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Channels holds the value of the channels edge. 
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"` + // AdoptionDecisions holds the value of the adoption_decisions edge. + AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// ChannelsOrErr returns the Channels value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) { + if e.loadedTypes[1] { + return e.Channels, nil + } + return nil, &NotLoadedError{edge: "channels"} +} + +// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) { + if e.loadedTypes[2] { + return e.AdoptionDecisions, nil + } + return nil, &NotLoadedError{edge: "adoption_decisions"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentity.FieldMetadata: + values[i] = new([]byte) + case authidentity.FieldID, authidentity.FieldUserID: + values[i] = new(sql.NullInt64) + case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer: + values[i] = new(sql.NullString) + case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentity fields. +func (_m *AuthIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentity.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case authidentity.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type 
%T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentity.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentity.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case authidentity.FieldVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field verified_at", values[i]) + } else if value.Valid { + _m.VerifiedAt = new(time.Time) + *_m.VerifiedAt = value.Time + } + case authidentity.FieldIssuer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field issuer", values[i]) + } else if value.Valid { + _m.Issuer = new(string) + *_m.Issuer = value.String + } + case authidentity.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryUser() *UserQuery { + return NewAuthIdentityClient(_m.config).QueryUser(_m) +} + +// QueryChannels queries the "channels" edge of the AuthIdentity entity. 
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery { + return NewAuthIdentityClient(_m.config).QueryChannels(_m) +} + +// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m) +} + +// Update returns a builder for updating this AuthIdentity. +// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne { + return NewAuthIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentity) Unwrap() *AuthIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *AuthIdentity) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.VerifiedAt; v != nil { + builder.WriteString("verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Issuer; v != nil { + builder.WriteString("issuer=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentities is a parsable slice of AuthIdentity. +type AuthIdentities []*AuthIdentity diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go new file mode 100644 index 0000000000000000000000000000000000000000..c90be759e827db02e0f8637d6a2bdfe559e10303 --- /dev/null +++ b/backend/ent/authidentity/authidentity.go @@ -0,0 +1,209 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentity type in the database. + Label = "auth_identity" + // FieldID holds the string denoting the id field in the database. 
+ FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldVerifiedAt holds the string denoting the verified_at field in the database. + FieldVerifiedAt = "verified_at" + // FieldIssuer holds the string denoting the issuer field in the database. + FieldIssuer = "issuer" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeChannels holds the string denoting the channels edge name in mutations. + EdgeChannels = "channels" + // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations. + EdgeAdoptionDecisions = "adoption_decisions" + // Table holds the table name of the authidentity in the database. + Table = "auth_identities" + // UserTable is the table that holds the user relation/edge. + UserTable = "auth_identities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // ChannelsTable is the table that holds the channels relation/edge. 
+ ChannelsTable = "auth_identity_channels" + // ChannelsInverseTable is the table name for the AuthIdentityChannel entity. + // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package. + ChannelsInverseTable = "auth_identity_channels" + // ChannelsColumn is the table column denoting the channels relation/edge. + ChannelsColumn = "identity_id" + // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge. + AdoptionDecisionsTable = "identity_adoption_decisions" + // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionsInverseTable = "identity_adoption_decisions" + // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge. + AdoptionDecisionsColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldUserID, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldVerifiedAt, + FieldIssuer, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. 
+ ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByVerifiedAt orders the results by the verified_at field. 
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc() +} + +// ByIssuer orders the results by the issuer field. +func ByIssuer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIssuer, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChannelsCount orders the results by channels count. +func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...) + } +} + +// ByChannels orders the results by channels terms. +func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAdoptionDecisionsCount orders the results by adoption_decisions count. +func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...) + } +} + +// ByAdoptionDecisions orders the results by adoption_decisions terms. +func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newChannelsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChannelsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) +} +func newAdoptionDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) +} diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go new file mode 100644 index 0000000000000000000000000000000000000000..3dbf317879b8813b2dee21acdda4f6122d477865 --- /dev/null +++ b/backend/ent/authidentity/where.go @@ -0,0 +1,600 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. 
+func IDNotIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. 
It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ. +func VerifiedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ. +func Issuer(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. 
+func CreatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. 
+func UserIDEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. 
+func ProviderTypeLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. 
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. 
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. 
+func ProviderSubjectLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// VerifiedAtEQ applies the EQ predicate on the "verified_at" field. +func VerifiedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field. +func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtIn applies the In predicate on the "verified_at" field. 
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field. +func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtGT applies the GT predicate on the "verified_at" field. +func VerifiedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v)) +} + +// VerifiedAtGTE applies the GTE predicate on the "verified_at" field. +func VerifiedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v)) +} + +// VerifiedAtLT applies the LT predicate on the "verified_at" field. +func VerifiedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v)) +} + +// VerifiedAtLTE applies the LTE predicate on the "verified_at" field. +func VerifiedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v)) +} + +// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field. +func VerifiedAtIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt)) +} + +// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field. +func VerifiedAtNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt)) +} + +// IssuerEQ applies the EQ predicate on the "issuer" field. +func IssuerEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// IssuerNEQ applies the NEQ predicate on the "issuer" field. +func IssuerNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v)) +} + +// IssuerIn applies the In predicate on the "issuer" field. 
+func IssuerIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...)) +} + +// IssuerNotIn applies the NotIn predicate on the "issuer" field. +func IssuerNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...)) +} + +// IssuerGT applies the GT predicate on the "issuer" field. +func IssuerGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v)) +} + +// IssuerGTE applies the GTE predicate on the "issuer" field. +func IssuerGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v)) +} + +// IssuerLT applies the LT predicate on the "issuer" field. +func IssuerLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v)) +} + +// IssuerLTE applies the LTE predicate on the "issuer" field. +func IssuerLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v)) +} + +// IssuerContains applies the Contains predicate on the "issuer" field. +func IssuerContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v)) +} + +// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field. +func IssuerHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v)) +} + +// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field. +func IssuerHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v)) +} + +// IssuerIsNil applies the IsNil predicate on the "issuer" field. +func IssuerIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer)) +} + +// IssuerNotNil applies the NotNil predicate on the "issuer" field. 
+func IssuerNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer)) +} + +// IssuerEqualFold applies the EqualFold predicate on the "issuer" field. +func IssuerEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v)) +} + +// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field. +func IssuerContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChannels applies the HasEdge predicate on the "channels" edge. +func HasChannels() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates). 
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newChannelsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge. +func HasAdoptionDecisions() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates). +func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newAdoptionDecisionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go new file mode 100644 index 0000000000000000000000000000000000000000..e287705ce2af71def4c2140c10877347c24ff459 --- /dev/null +++ b/backend/ent/authidentity_create.go @@ -0,0 +1,1036 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityCreate is the builder for creating a AuthIdentity entity. +type AuthIdentityCreate struct { + config + mutation *AuthIdentityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. 
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetVerifiedAt sets the "verified_at" field. +func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetVerifiedAt(v) + return _c +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetVerifiedAt(*v) + } + return _c +} + +// SetIssuer sets the "issuer" field. +func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate { + _c.mutation.SetIssuer(v) + return _c +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate { + if v != nil { + _c.SetIssuer(*v) + } + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate { + return _c.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddChannelIDs(ids...) 
+ return _c +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddAdoptionDecisionIDs(ids...) + return _c +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation { + return _c.mutation +} + +// Save creates the AuthIdentity in the database. +func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *AuthIdentityCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentity.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentity.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentity.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: 
errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)} + } + return nil +} + +func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + 
_spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + _node.VerifiedAt = &value + } + if value, ok := _c.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + _node.Issuer = &value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: 
sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne { + _c.conflict = opts + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityUpsertOne is the builder for "upsert"-ing + // one AuthIdentity node. + AuthIdentityUpsertOne struct { + create *AuthIdentityCreate + } + + // AuthIdentityUpsert is the "OnConflict" setter. + AuthIdentityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUpdatedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert { + u.Set(authidentity.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUserID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderSubject) + return u +} + +// SetVerifiedAt sets the "verified_at" field. 
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldVerifiedAt, v) + return u +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldVerifiedAt) + return u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldVerifiedAt) + return u +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldIssuer, v) + return u +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldIssuer) + return u +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldIssuer) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert { + u.Set(authidentity.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. 
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. 
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. 
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk. +type AuthIdentityCreateBulk struct { + config + err error + builders []*AuthIdentityCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentity entities in the database. 
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk { + _c.conflict = opts + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// AuthIdentityUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentity nodes. +type AuthIdentityUpsertBulk struct { + create *AuthIdentityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. 
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. 
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. 
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..4f1f6f3ce48d3a28fb14bd08fdcc11e4b4420576 --- /dev/null +++ b/backend/ent/authidentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityDelete is the builder for deleting a AuthIdentity entity. +type AuthIdentityDelete struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity. +type AuthIdentityDeleteOne struct { + _d *AuthIdentityDelete +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go new file mode 100644 index 0000000000000000000000000000000000000000..ff27ef3cd260d445c356f55e39e761055fd25ac0 --- /dev/null +++ b/backend/ent/authidentity_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityQuery is the builder for querying AuthIdentity entities. +type AuthIdentityQuery struct { + config + ctx *QueryContext + order []authidentity.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentity + withUser *UserQuery + withChannels *AuthIdentityChannelQuery + withAdoptionDecisions *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityQuery builder. +func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery { + _q.order = append(_q.order, o...) 
+ return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *AuthIdentityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChannels chains the current query on the "channels" edge. +func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge. 
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentity entity from the query. +// Returns a *NotFoundError when no AuthIdentity was found. +func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentity ID from the query. +// Returns a *NotFoundError when no AuthIdentity ID was found. +func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. 
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentity entity is found. +// Returns a *NotFoundError when no AuthIdentity entities are found. +func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentity.Label} + default: + return nil, &NotSingularError{authidentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentity ID in the query. +// Returns a *NotSingularError when more than one AuthIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentity.Label} + default: + err = &NotSingularError{authidentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentities. 
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]() + return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentity IDs. +func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. 
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery { + if _q == nil { + return nil + } + return &AuthIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentity{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withChannels: _q.withChannels.Clone(), + withAdoptionDecisions: _q.withAdoptionDecisions.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithChannels tells the query-builder to eager-load the nodes that are connected to +// the "channels" edge. The optional arguments are used to configure the query builder of the edge. 
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withChannels = query + return _q +} + +// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecisions = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// GroupBy(authidentity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// Select(authidentity.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) 
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q} + sbuild.label = authidentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentitySelect configured with the given aggregations. +func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) { + var ( + nodes = []*AuthIdentity{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withUser != nil, + _q.withChannels != nil, + _q.withAdoptionDecisions != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentity{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *AuthIdentity, e 
*User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withChannels; query != nil { + if err := _q.loadChannels(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} }, + func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecisions; query != nil { + if err := _q.loadAdoptionDecisions(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} }, + func(n *AuthIdentity, e *IdentityAdoptionDecision) { + n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentity) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 
{ + query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID) + } + query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + if fk == nil { + return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q 
*AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for i := range fields { + if fields[i] != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(authidentity.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities. +type AuthIdentityGroupBy struct { + selector + build *AuthIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities. +type AuthIdentitySelect struct { + *AuthIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go new file mode 100644 index 0000000000000000000000000000000000000000..c457470b9b17b7bd03231239b19455e5e53d893a --- /dev/null +++ b/backend/ent/authidentity_update.go @@ -0,0 +1,923 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityUpdate is the builder for updating AuthIdentity entities. +type AuthIdentityUpdate struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. 
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. 
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. 
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. 
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *AuthIdentityUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + 
_spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: 
&sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: 
&sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity. +type AuthIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. 
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. 
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. 
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. 
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentity entity. +func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *AuthIdentityUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for _, f := range fields { + if !authidentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector 
*sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if 
_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: 
false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff3e5d1c88a1d5e596e7ee998c6d14eb1f34407 --- /dev/null +++ b/backend/ent/authidentitychannel.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema. 
+type AuthIdentityChannel struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID int64 `json:"identity_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // Channel holds the value of the "channel" field. + Channel string `json:"channel,omitempty"` + // ChannelAppID holds the value of the "channel_app_id" field. + ChannelAppID string `json:"channel_app_id,omitempty"` + // ChannelSubject holds the value of the "channel_subject" field. + ChannelSubject string `json:"channel_subject,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set. + Edges AuthIdentityChannelEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityChannelEdges struct { + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. 
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldMetadata: + values[i] = new([]byte) + case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID: + values[i] = new(sql.NullInt64) + case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject: + values[i] = new(sql.NullString) + case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentityChannel fields. 
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentitychannel.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentitychannel.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentitychannel.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = value.Int64 + } + case authidentitychannel.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentitychannel.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentitychannel.FieldChannel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel", values[i]) + } else if value.Valid { + _m.Channel = value.String + } + case authidentitychannel.FieldChannelAppID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_app_id", values[i]) + 
} else if value.Valid { + _m.ChannelAppID = value.String + } + case authidentitychannel.FieldChannelSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_subject", values[i]) + } else if value.Valid { + _m.ChannelSubject = value.String + } + case authidentitychannel.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity. +func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery { + return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this AuthIdentityChannel. +// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne { + return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. 
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentityChannel is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentityChannel) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentityChannel(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", _m.IdentityID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("channel=") + builder.WriteString(_m.Channel) + builder.WriteString(", ") + builder.WriteString("channel_app_id=") + builder.WriteString(_m.ChannelAppID) + builder.WriteString(", ") + builder.WriteString("channel_subject=") + builder.WriteString(_m.ChannelSubject) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentityChannels is a parsable slice of AuthIdentityChannel. +type AuthIdentityChannels []*AuthIdentityChannel diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go new file mode 100644 index 0000000000000000000000000000000000000000..7dcc98bb60b185f24fa7f97de257c4cb480ca811 --- /dev/null +++ b/backend/ent/authidentitychannel/authidentitychannel.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentitychannel type in the database. + Label = "auth_identity_channel" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldChannel holds the string denoting the channel field in the database. + FieldChannel = "channel" + // FieldChannelAppID holds the string denoting the channel_app_id field in the database. + FieldChannelAppID = "channel_app_id" + // FieldChannelSubject holds the string denoting the channel_subject field in the database. + FieldChannelSubject = "channel_subject" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the authidentitychannel in the database. + Table = "auth_identity_channels" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "auth_identity_channels" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. 
+ IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentitychannel fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldIdentityID, + FieldProviderType, + FieldProviderKey, + FieldChannel, + FieldChannelAppID, + FieldChannelSubject, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + ChannelValidator func(string) error + // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + ChannelAppIDValidator func(string) error + // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + ChannelSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. 
+ DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentityChannel queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByChannel orders the results by the channel field. +func ByChannel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannel, opts...).ToFunc() +} + +// ByChannelAppID orders the results by the channel_app_id field. +func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelAppID, opts...).ToFunc() +} + +// ByChannelSubject orders the results by the channel_subject field. +func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelSubject, opts...).ToFunc() +} + +// ByIdentityField orders the results by identity field. 
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go new file mode 100644 index 0000000000000000000000000000000000000000..827dc38450ede777c9442f4855c6c234ad03f1c2 --- /dev/null +++ b/backend/ent/authidentitychannel/where.go @@ -0,0 +1,559 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. 
+func IDGT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ. 
+func Channel(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ. +func ChannelAppID(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ. +func ChannelSubject(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. 
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. 
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. 
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. 
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. 
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ChannelEQ applies the EQ predicate on the "channel" field. +func ChannelEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelNEQ applies the NEQ predicate on the "channel" field. +func ChannelNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v)) +} + +// ChannelIn applies the In predicate on the "channel" field. +func ChannelIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...)) +} + +// ChannelNotIn applies the NotIn predicate on the "channel" field. +func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...)) +} + +// ChannelGT applies the GT predicate on the "channel" field. 
+func ChannelGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v)) +} + +// ChannelGTE applies the GTE predicate on the "channel" field. +func ChannelGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v)) +} + +// ChannelLT applies the LT predicate on the "channel" field. +func ChannelLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v)) +} + +// ChannelLTE applies the LTE predicate on the "channel" field. +func ChannelLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v)) +} + +// ChannelContains applies the Contains predicate on the "channel" field. +func ChannelContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v)) +} + +// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field. +func ChannelHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v)) +} + +// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field. +func ChannelHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v)) +} + +// ChannelEqualFold applies the EqualFold predicate on the "channel" field. +func ChannelEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v)) +} + +// ChannelContainsFold applies the ContainsFold predicate on the "channel" field. +func ChannelContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v)) +} + +// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field. 
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field. +func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDIn applies the In predicate on the "channel_app_id" field. +func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field. +func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field. +func ChannelAppIDGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v)) +} + +// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field. +func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v)) +} + +// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field. +func ChannelAppIDLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v)) +} + +// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field. +func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v)) +} + +// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field. 
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v)) +} + +// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field. +func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v)) +} + +// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field. +func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v)) +} + +// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field. +func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v)) +} + +// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field. +func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v)) +} + +// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field. +func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field. +func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectIn applies the In predicate on the "channel_subject" field. +func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field. 
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectGT applies the GT predicate on the "channel_subject" field. +func ChannelSubjectGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v)) +} + +// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field. +func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v)) +} + +// ChannelSubjectLT applies the LT predicate on the "channel_subject" field. +func ChannelSubjectLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v)) +} + +// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field. +func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v)) +} + +// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field. +func ChannelSubjectContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v)) +} + +// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field. +func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v)) +} + +// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field. +func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v)) +} + +// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field. 
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v)) +} + +// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field. +func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v)) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go new file mode 100644 index 0000000000000000000000000000000000000000..4ce284792b16cdad35c6817a6b256cd0ec366be5 --- /dev/null +++ b/backend/ent/authidentitychannel_create.go @@ -0,0 +1,932 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity. +type AuthIdentityChannelCreate struct { + config + mutation *AuthIdentityChannelMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetIdentityID sets the "identity_id" field. 
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetChannel sets the "channel" field. +func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannel(v) + return _c +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelAppID(v) + return _c +} + +// SetChannelSubject sets the "channel_subject" field. +func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelSubject(v) + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation { + return _c.mutation +} + +// Save creates the AuthIdentityChannel in the database. +func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. 
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityChannelCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentitychannel.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentitychannel.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentitychannel.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *AuthIdentityChannelCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)} + } + if _, ok := _c.mutation.IdentityID(); !ok { + return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.Channel(); !ok { + return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)} + } + if v, ok := _c.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelAppID(); !ok { + return &ValidationError{Name: 
"channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)} + } + if v, ok := _c.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelSubject(); !ok { + return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)} + } + if v, ok := _c.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)} + } + if len(_c.mutation.IdentityIDs()) == 0 { + return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)} + } + return nil +} + +func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentityChannel{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, 
field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + _node.Channel = value + } + if value, ok := _c.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + _node.ChannelAppID = value + } + if value, ok := _c.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + _node.ChannelSubject = value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. 
For example: +// +// client.AuthIdentityChannel.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne { + _c.conflict = opts + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing + // one AuthIdentityChannel node. + AuthIdentityChannelUpsertOne struct { + create *AuthIdentityChannelCreate + } + + // AuthIdentityChannelUpsert is the "OnConflict" setter. + AuthIdentityChannelUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldUpdatedAt) + return u +} + +// SetIdentityID sets the "identity_id" field. 
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldIdentityID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderKey) + return u +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannel, v) + return u +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannel) + return u +} + +// SetChannelAppID sets the "channel_app_id" field. 
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelAppID, v) + return u +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelAppID) + return u +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelSubject, v) + return u +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelSubject) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. 
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. 
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk. +type AuthIdentityChannelCreateBulk struct { + config + err error + builders []*AuthIdentityChannelCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentityChannel entities in the database. 
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentityChannel, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityChannelMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk { + _c.conflict = opts + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentityChannel nodes. +type AuthIdentityChannelUpsertBulk struct { + create *AuthIdentityChannelCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. 
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. 
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..1a4acac59063fad5d80c787dec54c431928672e8 --- /dev/null +++ b/backend/ent/authidentitychannel_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity. +type AuthIdentityChannelDelete struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. 
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity. +type AuthIdentityChannelDeleteOne struct { + _d *AuthIdentityChannelDelete +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentitychannel.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go new file mode 100644 index 0000000000000000000000000000000000000000..7a202b7f1fd9573923b6f29d7bc18e0373efacd4 --- /dev/null +++ b/backend/ent/authidentitychannel_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities. +type AuthIdentityChannelQuery struct { + config + ctx *QueryContext + order []authidentitychannel.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentityChannel + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityChannelQuery builder. +func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. 
+// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentityChannel entity from the query. +// Returns a *NotFoundError when no AuthIdentityChannel was found. +func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentitychannel.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. 
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentityChannel ID from the query. +// Returns a *NotFoundError when no AuthIdentityChannel ID was found. +func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentitychannel.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found. +// Returns a *NotFoundError when no AuthIdentityChannel entities are found. +func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentitychannel.Label} + default: + return nil, &NotSingularError{authidentitychannel.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query. +// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found. 
+// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentitychannel.Label} + default: + err = &NotSingularError{authidentitychannel.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentityChannels. +func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]() + return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentityChannel IDs. +func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. 
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery { + if _q == nil { + return nil + } + return &AuthIdentityChannelQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentitychannel.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. 
The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// GroupBy(authidentitychannel.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityChannelGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentitychannel.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// Select(authidentitychannel.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q} + sbuild.label = authidentitychannel.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations. 
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentitychannel.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) { + var ( + nodes = []*AuthIdentityChannel{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentityChannel).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentityChannel{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, 
init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentityChannel) + for i := range nodes { + fk := nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for i := range fields { + if fields[i] != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + 
ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentitychannel.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentitychannel.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. 
Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities. +type AuthIdentityChannelGroupBy struct { + selector + build *AuthIdentityChannelQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) 
+ if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities. +type AuthIdentityChannelSelect struct { + *AuthIdentityChannelQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go new file mode 100644 index 0000000000000000000000000000000000000000..b550c4545fdf8187dff66fd9dc574920270dde9a --- /dev/null +++ b/backend/ent/authidentitychannel_update.go @@ -0,0 +1,581 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities. +type AuthIdentityChannelUpdate struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. 
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. 
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *AuthIdentityChannelUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + 
_spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := 
range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity. +type AuthIdentityChannelUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. 
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. 
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentityChannel entity. +func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. 
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field 
"AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for _, f := range fields { + if !authidentitychannel.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, 
field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentityChannel{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go new file mode 100644 index 0000000000000000000000000000000000000000..dbb733624d501270a381977276690a6536ec7ab1 --- /dev/null +++ b/backend/ent/channelmonitor.go @@ -0,0 +1,359 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" +) + +// ChannelMonitor is the model entity for the ChannelMonitor schema. +type ChannelMonitor struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Provider holds the value of the "provider" field. + Provider channelmonitor.Provider `json:"provider,omitempty"` + // Provider base origin, e.g. https://api.openai.com + Endpoint string `json:"endpoint,omitempty"` + // AES-256-GCM encrypted API key + APIKeyEncrypted string `json:"-"` + // PrimaryModel holds the value of the "primary_model" field. + PrimaryModel string `json:"primary_model,omitempty"` + // Additional model names to test alongside primary_model + ExtraModels []string `json:"extra_models,omitempty"` + // GroupName holds the value of the "group_name" field. + GroupName string `json:"group_name,omitempty"` + // Enabled holds the value of the "enabled" field. + Enabled bool `json:"enabled,omitempty"` + // IntervalSeconds holds the value of the "interval_seconds" field. + IntervalSeconds int `json:"interval_seconds,omitempty"` + // LastCheckedAt holds the value of the "last_checked_at" field. + LastCheckedAt *time.Time `json:"last_checked_at,omitempty"` + // CreatedBy holds the value of the "created_by" field. + CreatedBy int64 `json:"created_by,omitempty"` + // TemplateID holds the value of the "template_id" field. + TemplateID *int64 `json:"template_id,omitempty"` + // ExtraHeaders holds the value of the "extra_headers" field. 
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // BodyOverrideMode holds the value of the "body_override_mode" field. + BodyOverrideMode string `json:"body_override_mode,omitempty"` + // BodyOverride holds the value of the "body_override" field. + BodyOverride map[string]interface{} `json:"body_override,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ChannelMonitorQuery when eager-loading is set. + Edges ChannelMonitorEdges `json:"edges"` + selectValues sql.SelectValues +} + +// ChannelMonitorEdges holds the relations/edges for other nodes in the graph. +type ChannelMonitorEdges struct { + // History holds the value of the history edge. + History []*ChannelMonitorHistory `json:"history,omitempty"` + // DailyRollups holds the value of the daily_rollups edge. + DailyRollups []*ChannelMonitorDailyRollup `json:"daily_rollups,omitempty"` + // RequestTemplate holds the value of the request_template edge. + RequestTemplate *ChannelMonitorRequestTemplate `json:"request_template,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// HistoryOrErr returns the History value or an error if the edge +// was not loaded in eager-loading. +func (e ChannelMonitorEdges) HistoryOrErr() ([]*ChannelMonitorHistory, error) { + if e.loadedTypes[0] { + return e.History, nil + } + return nil, &NotLoadedError{edge: "history"} +} + +// DailyRollupsOrErr returns the DailyRollups value or an error if the edge +// was not loaded in eager-loading. +func (e ChannelMonitorEdges) DailyRollupsOrErr() ([]*ChannelMonitorDailyRollup, error) { + if e.loadedTypes[1] { + return e.DailyRollups, nil + } + return nil, &NotLoadedError{edge: "daily_rollups"} +} + +// RequestTemplateOrErr returns the RequestTemplate value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. 
+func (e ChannelMonitorEdges) RequestTemplateOrErr() (*ChannelMonitorRequestTemplate, error) { + if e.RequestTemplate != nil { + return e.RequestTemplate, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: channelmonitorrequesttemplate.Label} + } + return nil, &NotLoadedError{edge: "request_template"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ChannelMonitor) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case channelmonitor.FieldExtraModels, channelmonitor.FieldExtraHeaders, channelmonitor.FieldBodyOverride: + values[i] = new([]byte) + case channelmonitor.FieldEnabled: + values[i] = new(sql.NullBool) + case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID: + values[i] = new(sql.NullInt64) + case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode: + values[i] = new(sql.NullString) + case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ChannelMonitor fields. 
+func (_m *ChannelMonitor) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case channelmonitor.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case channelmonitor.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case channelmonitor.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case channelmonitor.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case channelmonitor.FieldProvider: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider", values[i]) + } else if value.Valid { + _m.Provider = channelmonitor.Provider(value.String) + } + case channelmonitor.FieldEndpoint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field endpoint", values[i]) + } else if value.Valid { + _m.Endpoint = value.String + } + case channelmonitor.FieldAPIKeyEncrypted: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field api_key_encrypted", values[i]) + } else if value.Valid { + _m.APIKeyEncrypted = value.String + } + case channelmonitor.FieldPrimaryModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field primary_model", values[i]) + } else if value.Valid { + 
_m.PrimaryModel = value.String + } + case channelmonitor.FieldExtraModels: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field extra_models", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ExtraModels); err != nil { + return fmt.Errorf("unmarshal field extra_models: %w", err) + } + } + case channelmonitor.FieldGroupName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field group_name", values[i]) + } else if value.Valid { + _m.GroupName = value.String + } + case channelmonitor.FieldEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field enabled", values[i]) + } else if value.Valid { + _m.Enabled = value.Bool + } + case channelmonitor.FieldIntervalSeconds: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field interval_seconds", values[i]) + } else if value.Valid { + _m.IntervalSeconds = int(value.Int64) + } + case channelmonitor.FieldLastCheckedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_checked_at", values[i]) + } else if value.Valid { + _m.LastCheckedAt = new(time.Time) + *_m.LastCheckedAt = value.Time + } + case channelmonitor.FieldCreatedBy: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field created_by", values[i]) + } else if value.Valid { + _m.CreatedBy = value.Int64 + } + case channelmonitor.FieldTemplateID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field template_id", values[i]) + } else if value.Valid { + _m.TemplateID = new(int64) + *_m.TemplateID = value.Int64 + } + case channelmonitor.FieldExtraHeaders: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field extra_headers", values[i]) + } else if value != nil 
&& len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil { + return fmt.Errorf("unmarshal field extra_headers: %w", err) + } + } + case channelmonitor.FieldBodyOverrideMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field body_override_mode", values[i]) + } else if value.Valid { + _m.BodyOverrideMode = value.String + } + case channelmonitor.FieldBodyOverride: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field body_override", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil { + return fmt.Errorf("unmarshal field body_override: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitor. +// This includes values selected through modifiers, order, etc. +func (_m *ChannelMonitor) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryHistory queries the "history" edge of the ChannelMonitor entity. +func (_m *ChannelMonitor) QueryHistory() *ChannelMonitorHistoryQuery { + return NewChannelMonitorClient(_m.config).QueryHistory(_m) +} + +// QueryDailyRollups queries the "daily_rollups" edge of the ChannelMonitor entity. +func (_m *ChannelMonitor) QueryDailyRollups() *ChannelMonitorDailyRollupQuery { + return NewChannelMonitorClient(_m.config).QueryDailyRollups(_m) +} + +// QueryRequestTemplate queries the "request_template" edge of the ChannelMonitor entity. +func (_m *ChannelMonitor) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery { + return NewChannelMonitorClient(_m.config).QueryRequestTemplate(_m) +} + +// Update returns a builder for updating this ChannelMonitor. 
+// Note that you need to call ChannelMonitor.Unwrap() before calling this method if this ChannelMonitor +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ChannelMonitor) Update() *ChannelMonitorUpdateOne { + return NewChannelMonitorClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ChannelMonitor entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ChannelMonitor) Unwrap() *ChannelMonitor { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ChannelMonitor is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ChannelMonitor) String() string { + var builder strings.Builder + builder.WriteString("ChannelMonitor(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + builder.WriteString("provider=") + builder.WriteString(fmt.Sprintf("%v", _m.Provider)) + builder.WriteString(", ") + builder.WriteString("endpoint=") + builder.WriteString(_m.Endpoint) + builder.WriteString(", ") + builder.WriteString("api_key_encrypted=") + builder.WriteString(", ") + builder.WriteString("primary_model=") + builder.WriteString(_m.PrimaryModel) + builder.WriteString(", ") + builder.WriteString("extra_models=") + builder.WriteString(fmt.Sprintf("%v", _m.ExtraModels)) + builder.WriteString(", ") + builder.WriteString("group_name=") + builder.WriteString(_m.GroupName) + builder.WriteString(", ") + builder.WriteString("enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.Enabled)) + 
builder.WriteString(", ") + builder.WriteString("interval_seconds=") + builder.WriteString(fmt.Sprintf("%v", _m.IntervalSeconds)) + builder.WriteString(", ") + if v := _m.LastCheckedAt; v != nil { + builder.WriteString("last_checked_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("created_by=") + builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy)) + builder.WriteString(", ") + if v := _m.TemplateID; v != nil { + builder.WriteString("template_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("extra_headers=") + builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders)) + builder.WriteString(", ") + builder.WriteString("body_override_mode=") + builder.WriteString(_m.BodyOverrideMode) + builder.WriteString(", ") + builder.WriteString("body_override=") + builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride)) + builder.WriteByte(')') + return builder.String() +} + +// ChannelMonitors is a parsable slice of ChannelMonitor. +type ChannelMonitors []*ChannelMonitor diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go new file mode 100644 index 0000000000000000000000000000000000000000..e5a6bfe70af1dd0e1d66fdd05a788a05a5c144ca --- /dev/null +++ b/backend/ent/channelmonitor/channelmonitor.go @@ -0,0 +1,304 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitor + +import ( + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the channelmonitor type in the database. + Label = "channel_monitor" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. 
+ FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldProvider holds the string denoting the provider field in the database. + FieldProvider = "provider" + // FieldEndpoint holds the string denoting the endpoint field in the database. + FieldEndpoint = "endpoint" + // FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database. + FieldAPIKeyEncrypted = "api_key_encrypted" + // FieldPrimaryModel holds the string denoting the primary_model field in the database. + FieldPrimaryModel = "primary_model" + // FieldExtraModels holds the string denoting the extra_models field in the database. + FieldExtraModels = "extra_models" + // FieldGroupName holds the string denoting the group_name field in the database. + FieldGroupName = "group_name" + // FieldEnabled holds the string denoting the enabled field in the database. + FieldEnabled = "enabled" + // FieldIntervalSeconds holds the string denoting the interval_seconds field in the database. + FieldIntervalSeconds = "interval_seconds" + // FieldLastCheckedAt holds the string denoting the last_checked_at field in the database. + FieldLastCheckedAt = "last_checked_at" + // FieldCreatedBy holds the string denoting the created_by field in the database. + FieldCreatedBy = "created_by" + // FieldTemplateID holds the string denoting the template_id field in the database. + FieldTemplateID = "template_id" + // FieldExtraHeaders holds the string denoting the extra_headers field in the database. + FieldExtraHeaders = "extra_headers" + // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database. + FieldBodyOverrideMode = "body_override_mode" + // FieldBodyOverride holds the string denoting the body_override field in the database. + FieldBodyOverride = "body_override" + // EdgeHistory holds the string denoting the history edge name in mutations. 
+ EdgeHistory = "history" + // EdgeDailyRollups holds the string denoting the daily_rollups edge name in mutations. + EdgeDailyRollups = "daily_rollups" + // EdgeRequestTemplate holds the string denoting the request_template edge name in mutations. + EdgeRequestTemplate = "request_template" + // Table holds the table name of the channelmonitor in the database. + Table = "channel_monitors" + // HistoryTable is the table that holds the history relation/edge. + HistoryTable = "channel_monitor_histories" + // HistoryInverseTable is the table name for the ChannelMonitorHistory entity. + // It exists in this package in order to avoid circular dependency with the "channelmonitorhistory" package. + HistoryInverseTable = "channel_monitor_histories" + // HistoryColumn is the table column denoting the history relation/edge. + HistoryColumn = "monitor_id" + // DailyRollupsTable is the table that holds the daily_rollups relation/edge. + DailyRollupsTable = "channel_monitor_daily_rollups" + // DailyRollupsInverseTable is the table name for the ChannelMonitorDailyRollup entity. + // It exists in this package in order to avoid circular dependency with the "channelmonitordailyrollup" package. + DailyRollupsInverseTable = "channel_monitor_daily_rollups" + // DailyRollupsColumn is the table column denoting the daily_rollups relation/edge. + DailyRollupsColumn = "monitor_id" + // RequestTemplateTable is the table that holds the request_template relation/edge. + RequestTemplateTable = "channel_monitors" + // RequestTemplateInverseTable is the table name for the ChannelMonitorRequestTemplate entity. + // It exists in this package in order to avoid circular dependency with the "channelmonitorrequesttemplate" package. + RequestTemplateInverseTable = "channel_monitor_request_templates" + // RequestTemplateColumn is the table column denoting the request_template relation/edge. + RequestTemplateColumn = "template_id" +) + +// Columns holds all SQL columns for channelmonitor fields. 
+var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldProvider, + FieldEndpoint, + FieldAPIKeyEncrypted, + FieldPrimaryModel, + FieldExtraModels, + FieldGroupName, + FieldEnabled, + FieldIntervalSeconds, + FieldLastCheckedAt, + FieldCreatedBy, + FieldTemplateID, + FieldExtraHeaders, + FieldBodyOverrideMode, + FieldBodyOverride, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save. + EndpointValidator func(string) error + // APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save. + APIKeyEncryptedValidator func(string) error + // PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save. + PrimaryModelValidator func(string) error + // DefaultExtraModels holds the default value on creation for the "extra_models" field. + DefaultExtraModels []string + // DefaultGroupName holds the default value on creation for the "group_name" field. + DefaultGroupName string + // GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save. 
+ GroupNameValidator func(string) error + // DefaultEnabled holds the default value on creation for the "enabled" field. + DefaultEnabled bool + // IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save. + IntervalSecondsValidator func(int) error + // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field. + DefaultExtraHeaders map[string]string + // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field. + DefaultBodyOverrideMode string + // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save. + BodyOverrideModeValidator func(string) error +) + +// Provider defines the type for the "provider" enum field. +type Provider string + +// Provider values. +const ( + ProviderOpenai Provider = "openai" + ProviderAnthropic Provider = "anthropic" + ProviderGemini Provider = "gemini" +) + +func (pr Provider) String() string { + return string(pr) +} + +// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save. +func ProviderValidator(pr Provider) error { + switch pr { + case ProviderOpenai, ProviderAnthropic, ProviderGemini: + return nil + default: + return fmt.Errorf("channelmonitor: invalid enum value for provider field: %q", pr) + } +} + +// OrderOption defines the ordering options for the ChannelMonitor queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. 
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByProvider orders the results by the provider field. +func ByProvider(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProvider, opts...).ToFunc() +} + +// ByEndpoint orders the results by the endpoint field. +func ByEndpoint(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEndpoint, opts...).ToFunc() +} + +// ByAPIKeyEncrypted orders the results by the api_key_encrypted field. +func ByAPIKeyEncrypted(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKeyEncrypted, opts...).ToFunc() +} + +// ByPrimaryModel orders the results by the primary_model field. +func ByPrimaryModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrimaryModel, opts...).ToFunc() +} + +// ByGroupName orders the results by the group_name field. +func ByGroupName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGroupName, opts...).ToFunc() +} + +// ByEnabled orders the results by the enabled field. +func ByEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEnabled, opts...).ToFunc() +} + +// ByIntervalSeconds orders the results by the interval_seconds field. +func ByIntervalSeconds(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntervalSeconds, opts...).ToFunc() +} + +// ByLastCheckedAt orders the results by the last_checked_at field. +func ByLastCheckedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastCheckedAt, opts...).ToFunc() +} + +// ByCreatedBy orders the results by the created_by field. 
+func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedBy, opts...).ToFunc() +} + +// ByTemplateID orders the results by the template_id field. +func ByTemplateID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTemplateID, opts...).ToFunc() +} + +// ByBodyOverrideMode orders the results by the body_override_mode field. +func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc() +} + +// ByHistoryCount orders the results by history count. +func ByHistoryCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newHistoryStep(), opts...) + } +} + +// ByHistory orders the results by history terms. +func ByHistory(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newHistoryStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByDailyRollupsCount orders the results by daily_rollups count. +func ByDailyRollupsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDailyRollupsStep(), opts...) + } +} + +// ByDailyRollups orders the results by daily_rollups terms. +func ByDailyRollups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDailyRollupsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByRequestTemplateField orders the results by request_template field. 
+func ByRequestTemplateField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newRequestTemplateStep(), sql.OrderByField(field, opts...)) + } +} +func newHistoryStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(HistoryInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn), + ) +} +func newDailyRollupsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DailyRollupsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn), + ) +} +func newRequestTemplateStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(RequestTemplateInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn), + ) +} diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go new file mode 100644 index 0000000000000000000000000000000000000000..755d83a3ff13500d4c3988a78226a7c92ced14a4 --- /dev/null +++ b/backend/ent/channelmonitor/where.go @@ -0,0 +1,885 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitor + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. 
+func IDIn(ids ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v)) +} + +// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ. +func Endpoint(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v)) +} + +// APIKeyEncrypted applies equality check predicate on the "api_key_encrypted" field. It's identical to APIKeyEncryptedEQ. 
+func APIKeyEncrypted(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v)) +} + +// PrimaryModel applies equality check predicate on the "primary_model" field. It's identical to PrimaryModelEQ. +func PrimaryModel(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v)) +} + +// GroupName applies equality check predicate on the "group_name" field. It's identical to GroupNameEQ. +func GroupName(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v)) +} + +// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ. +func Enabled(v bool) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v)) +} + +// IntervalSeconds applies equality check predicate on the "interval_seconds" field. It's identical to IntervalSecondsEQ. +func IntervalSeconds(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v)) +} + +// LastCheckedAt applies equality check predicate on the "last_checked_at" field. It's identical to LastCheckedAtEQ. +func LastCheckedAt(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v)) +} + +// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ. +func CreatedBy(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v)) +} + +// TemplateID applies equality check predicate on the "template_id" field. It's identical to TemplateIDEQ. +func TemplateID(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v)) +} + +// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ. 
+func BodyOverrideMode(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. 
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. 
+func NameIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. 
+func NameContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldName, v)) +} + +// ProviderEQ applies the EQ predicate on the "provider" field. +func ProviderEQ(v Provider) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldProvider, v)) +} + +// ProviderNEQ applies the NEQ predicate on the "provider" field. +func ProviderNEQ(v Provider) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldProvider, v)) +} + +// ProviderIn applies the In predicate on the "provider" field. +func ProviderIn(vs ...Provider) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldProvider, vs...)) +} + +// ProviderNotIn applies the NotIn predicate on the "provider" field. +func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...)) +} + +// EndpointEQ applies the EQ predicate on the "endpoint" field. +func EndpointEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v)) +} + +// EndpointNEQ applies the NEQ predicate on the "endpoint" field. +func EndpointNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldEndpoint, v)) +} + +// EndpointIn applies the In predicate on the "endpoint" field. +func EndpointIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldEndpoint, vs...)) +} + +// EndpointNotIn applies the NotIn predicate on the "endpoint" field. +func EndpointNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldEndpoint, vs...)) +} + +// EndpointGT applies the GT predicate on the "endpoint" field. +func EndpointGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldEndpoint, v)) +} + +// EndpointGTE applies the GTE predicate on the "endpoint" field. 
+func EndpointGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldEndpoint, v)) +} + +// EndpointLT applies the LT predicate on the "endpoint" field. +func EndpointLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldEndpoint, v)) +} + +// EndpointLTE applies the LTE predicate on the "endpoint" field. +func EndpointLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldEndpoint, v)) +} + +// EndpointContains applies the Contains predicate on the "endpoint" field. +func EndpointContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldEndpoint, v)) +} + +// EndpointHasPrefix applies the HasPrefix predicate on the "endpoint" field. +func EndpointHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldEndpoint, v)) +} + +// EndpointHasSuffix applies the HasSuffix predicate on the "endpoint" field. +func EndpointHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldEndpoint, v)) +} + +// EndpointEqualFold applies the EqualFold predicate on the "endpoint" field. +func EndpointEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldEndpoint, v)) +} + +// EndpointContainsFold applies the ContainsFold predicate on the "endpoint" field. +func EndpointContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldEndpoint, v)) +} + +// APIKeyEncryptedEQ applies the EQ predicate on the "api_key_encrypted" field. +func APIKeyEncryptedEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedNEQ applies the NEQ predicate on the "api_key_encrypted" field. 
+func APIKeyEncryptedNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedIn applies the In predicate on the "api_key_encrypted" field. +func APIKeyEncryptedIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldAPIKeyEncrypted, vs...)) +} + +// APIKeyEncryptedNotIn applies the NotIn predicate on the "api_key_encrypted" field. +func APIKeyEncryptedNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIKeyEncrypted, vs...)) +} + +// APIKeyEncryptedGT applies the GT predicate on the "api_key_encrypted" field. +func APIKeyEncryptedGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedGTE applies the GTE predicate on the "api_key_encrypted" field. +func APIKeyEncryptedGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedLT applies the LT predicate on the "api_key_encrypted" field. +func APIKeyEncryptedLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedLTE applies the LTE predicate on the "api_key_encrypted" field. +func APIKeyEncryptedLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedContains applies the Contains predicate on the "api_key_encrypted" field. +func APIKeyEncryptedContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedHasPrefix applies the HasPrefix predicate on the "api_key_encrypted" field. 
+func APIKeyEncryptedHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedHasSuffix applies the HasSuffix predicate on the "api_key_encrypted" field. +func APIKeyEncryptedHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedEqualFold applies the EqualFold predicate on the "api_key_encrypted" field. +func APIKeyEncryptedEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIKeyEncrypted, v)) +} + +// APIKeyEncryptedContainsFold applies the ContainsFold predicate on the "api_key_encrypted" field. +func APIKeyEncryptedContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIKeyEncrypted, v)) +} + +// PrimaryModelEQ applies the EQ predicate on the "primary_model" field. +func PrimaryModelEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v)) +} + +// PrimaryModelNEQ applies the NEQ predicate on the "primary_model" field. +func PrimaryModelNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldPrimaryModel, v)) +} + +// PrimaryModelIn applies the In predicate on the "primary_model" field. +func PrimaryModelIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldPrimaryModel, vs...)) +} + +// PrimaryModelNotIn applies the NotIn predicate on the "primary_model" field. +func PrimaryModelNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldPrimaryModel, vs...)) +} + +// PrimaryModelGT applies the GT predicate on the "primary_model" field. 
+func PrimaryModelGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldPrimaryModel, v)) +} + +// PrimaryModelGTE applies the GTE predicate on the "primary_model" field. +func PrimaryModelGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldPrimaryModel, v)) +} + +// PrimaryModelLT applies the LT predicate on the "primary_model" field. +func PrimaryModelLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldPrimaryModel, v)) +} + +// PrimaryModelLTE applies the LTE predicate on the "primary_model" field. +func PrimaryModelLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldPrimaryModel, v)) +} + +// PrimaryModelContains applies the Contains predicate on the "primary_model" field. +func PrimaryModelContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldPrimaryModel, v)) +} + +// PrimaryModelHasPrefix applies the HasPrefix predicate on the "primary_model" field. +func PrimaryModelHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldPrimaryModel, v)) +} + +// PrimaryModelHasSuffix applies the HasSuffix predicate on the "primary_model" field. +func PrimaryModelHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldPrimaryModel, v)) +} + +// PrimaryModelEqualFold applies the EqualFold predicate on the "primary_model" field. +func PrimaryModelEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldPrimaryModel, v)) +} + +// PrimaryModelContainsFold applies the ContainsFold predicate on the "primary_model" field. 
+func PrimaryModelContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldPrimaryModel, v)) +} + +// GroupNameEQ applies the EQ predicate on the "group_name" field. +func GroupNameEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v)) +} + +// GroupNameNEQ applies the NEQ predicate on the "group_name" field. +func GroupNameNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldGroupName, v)) +} + +// GroupNameIn applies the In predicate on the "group_name" field. +func GroupNameIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldGroupName, vs...)) +} + +// GroupNameNotIn applies the NotIn predicate on the "group_name" field. +func GroupNameNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldGroupName, vs...)) +} + +// GroupNameGT applies the GT predicate on the "group_name" field. +func GroupNameGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldGroupName, v)) +} + +// GroupNameGTE applies the GTE predicate on the "group_name" field. +func GroupNameGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldGroupName, v)) +} + +// GroupNameLT applies the LT predicate on the "group_name" field. +func GroupNameLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldGroupName, v)) +} + +// GroupNameLTE applies the LTE predicate on the "group_name" field. +func GroupNameLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldGroupName, v)) +} + +// GroupNameContains applies the Contains predicate on the "group_name" field. 
+func GroupNameContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldGroupName, v)) +} + +// GroupNameHasPrefix applies the HasPrefix predicate on the "group_name" field. +func GroupNameHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldGroupName, v)) +} + +// GroupNameHasSuffix applies the HasSuffix predicate on the "group_name" field. +func GroupNameHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldGroupName, v)) +} + +// GroupNameIsNil applies the IsNil predicate on the "group_name" field. +func GroupNameIsNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIsNull(FieldGroupName)) +} + +// GroupNameNotNil applies the NotNil predicate on the "group_name" field. +func GroupNameNotNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotNull(FieldGroupName)) +} + +// GroupNameEqualFold applies the EqualFold predicate on the "group_name" field. +func GroupNameEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldGroupName, v)) +} + +// GroupNameContainsFold applies the ContainsFold predicate on the "group_name" field. +func GroupNameContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldGroupName, v)) +} + +// EnabledEQ applies the EQ predicate on the "enabled" field. +func EnabledEQ(v bool) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v)) +} + +// EnabledNEQ applies the NEQ predicate on the "enabled" field. +func EnabledNEQ(v bool) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldEnabled, v)) +} + +// IntervalSecondsEQ applies the EQ predicate on the "interval_seconds" field. 
+func IntervalSecondsEQ(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v)) +} + +// IntervalSecondsNEQ applies the NEQ predicate on the "interval_seconds" field. +func IntervalSecondsNEQ(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldIntervalSeconds, v)) +} + +// IntervalSecondsIn applies the In predicate on the "interval_seconds" field. +func IntervalSecondsIn(vs ...int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldIntervalSeconds, vs...)) +} + +// IntervalSecondsNotIn applies the NotIn predicate on the "interval_seconds" field. +func IntervalSecondsNotIn(vs ...int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldIntervalSeconds, vs...)) +} + +// IntervalSecondsGT applies the GT predicate on the "interval_seconds" field. +func IntervalSecondsGT(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldIntervalSeconds, v)) +} + +// IntervalSecondsGTE applies the GTE predicate on the "interval_seconds" field. +func IntervalSecondsGTE(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldIntervalSeconds, v)) +} + +// IntervalSecondsLT applies the LT predicate on the "interval_seconds" field. +func IntervalSecondsLT(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldIntervalSeconds, v)) +} + +// IntervalSecondsLTE applies the LTE predicate on the "interval_seconds" field. +func IntervalSecondsLTE(v int) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldIntervalSeconds, v)) +} + +// LastCheckedAtEQ applies the EQ predicate on the "last_checked_at" field. +func LastCheckedAtEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v)) +} + +// LastCheckedAtNEQ applies the NEQ predicate on the "last_checked_at" field. 
+func LastCheckedAtNEQ(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldLastCheckedAt, v)) +} + +// LastCheckedAtIn applies the In predicate on the "last_checked_at" field. +func LastCheckedAtIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldLastCheckedAt, vs...)) +} + +// LastCheckedAtNotIn applies the NotIn predicate on the "last_checked_at" field. +func LastCheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldLastCheckedAt, vs...)) +} + +// LastCheckedAtGT applies the GT predicate on the "last_checked_at" field. +func LastCheckedAtGT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldLastCheckedAt, v)) +} + +// LastCheckedAtGTE applies the GTE predicate on the "last_checked_at" field. +func LastCheckedAtGTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldLastCheckedAt, v)) +} + +// LastCheckedAtLT applies the LT predicate on the "last_checked_at" field. +func LastCheckedAtLT(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldLastCheckedAt, v)) +} + +// LastCheckedAtLTE applies the LTE predicate on the "last_checked_at" field. +func LastCheckedAtLTE(v time.Time) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldLastCheckedAt, v)) +} + +// LastCheckedAtIsNil applies the IsNil predicate on the "last_checked_at" field. +func LastCheckedAtIsNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIsNull(FieldLastCheckedAt)) +} + +// LastCheckedAtNotNil applies the NotNil predicate on the "last_checked_at" field. +func LastCheckedAtNotNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotNull(FieldLastCheckedAt)) +} + +// CreatedByEQ applies the EQ predicate on the "created_by" field. 
+func CreatedByEQ(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v)) +} + +// CreatedByNEQ applies the NEQ predicate on the "created_by" field. +func CreatedByNEQ(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedBy, v)) +} + +// CreatedByIn applies the In predicate on the "created_by" field. +func CreatedByIn(vs ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedBy, vs...)) +} + +// CreatedByNotIn applies the NotIn predicate on the "created_by" field. +func CreatedByNotIn(vs ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedBy, vs...)) +} + +// CreatedByGT applies the GT predicate on the "created_by" field. +func CreatedByGT(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedBy, v)) +} + +// CreatedByGTE applies the GTE predicate on the "created_by" field. +func CreatedByGTE(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedBy, v)) +} + +// CreatedByLT applies the LT predicate on the "created_by" field. +func CreatedByLT(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedBy, v)) +} + +// CreatedByLTE applies the LTE predicate on the "created_by" field. +func CreatedByLTE(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedBy, v)) +} + +// TemplateIDEQ applies the EQ predicate on the "template_id" field. +func TemplateIDEQ(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v)) +} + +// TemplateIDNEQ applies the NEQ predicate on the "template_id" field. +func TemplateIDNEQ(v int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldTemplateID, v)) +} + +// TemplateIDIn applies the In predicate on the "template_id" field. 
+func TemplateIDIn(vs ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldTemplateID, vs...)) +} + +// TemplateIDNotIn applies the NotIn predicate on the "template_id" field. +func TemplateIDNotIn(vs ...int64) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldTemplateID, vs...)) +} + +// TemplateIDIsNil applies the IsNil predicate on the "template_id" field. +func TemplateIDIsNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIsNull(FieldTemplateID)) +} + +// TemplateIDNotNil applies the NotNil predicate on the "template_id" field. +func TemplateIDNotNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotNull(FieldTemplateID)) +} + +// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field. +func BodyOverrideModeEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field. +func BodyOverrideModeNEQ(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNEQ(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field. +func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIn(FieldBodyOverrideMode, vs...)) +} + +// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field. +func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotIn(FieldBodyOverrideMode, vs...)) +} + +// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field. +func BodyOverrideModeGT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGT(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field. 
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldGTE(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field. +func BodyOverrideModeLT(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLT(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field. +func BodyOverrideModeLTE(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldLTE(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field. +func BodyOverrideModeContains(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContains(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field. +func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field. +func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field. +func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldEqualFold(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field. +func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldContainsFold(FieldBodyOverrideMode, v)) +} + +// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field. 
+func BodyOverrideIsNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldIsNull(FieldBodyOverride)) +} + +// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field. +func BodyOverrideNotNil() predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.FieldNotNull(FieldBodyOverride)) +} + +// HasHistory applies the HasEdge predicate on the "history" edge. +func HasHistory() predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasHistoryWith applies the HasEdge predicate on the "history" edge with a given conditions (other predicates). +func HasHistoryWith(preds ...predicate.ChannelMonitorHistory) predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := newHistoryStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasDailyRollups applies the HasEdge predicate on the "daily_rollups" edge. +func HasDailyRollups() predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasDailyRollupsWith applies the HasEdge predicate on the "daily_rollups" edge with a given conditions (other predicates). +func HasDailyRollupsWith(preds ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := newDailyRollupsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasRequestTemplate applies the HasEdge predicate on the "request_template" edge. 
+func HasRequestTemplate() predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasRequestTemplateWith applies the HasEdge predicate on the "request_template" edge with a given conditions (other predicates). +func HasRequestTemplateWith(preds ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitor { + return predicate.ChannelMonitor(func(s *sql.Selector) { + step := newRequestTemplateStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ChannelMonitor) predicate.ChannelMonitor { + return predicate.ChannelMonitor(sql.NotPredicates(p)) +} diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go new file mode 100644 index 0000000000000000000000000000000000000000..2f70c300eb6ff8e83b0f88fcc209642a4cca4d03 --- /dev/null +++ b/backend/ent/channelmonitor_create.go @@ -0,0 +1,1610 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" +) + +// ChannelMonitorCreate is the builder for creating a ChannelMonitor entity. +type ChannelMonitorCreate struct { + config + mutation *ChannelMonitorMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *ChannelMonitorCreate) SetCreatedAt(v time.Time) *ChannelMonitorCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *ChannelMonitorCreate) SetUpdatedAt(v time.Time) *ChannelMonitorCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *ChannelMonitorCreate) SetName(v string) *ChannelMonitorCreate { + _c.mutation.SetName(v) + return _c +} + +// SetProvider sets the "provider" field. +func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelMonitorCreate { + _c.mutation.SetProvider(v) + return _c +} + +// SetEndpoint sets the "endpoint" field. 
+func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate { + _c.mutation.SetEndpoint(v) + return _c +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (_c *ChannelMonitorCreate) SetAPIKeyEncrypted(v string) *ChannelMonitorCreate { + _c.mutation.SetAPIKeyEncrypted(v) + return _c +} + +// SetPrimaryModel sets the "primary_model" field. +func (_c *ChannelMonitorCreate) SetPrimaryModel(v string) *ChannelMonitorCreate { + _c.mutation.SetPrimaryModel(v) + return _c +} + +// SetExtraModels sets the "extra_models" field. +func (_c *ChannelMonitorCreate) SetExtraModels(v []string) *ChannelMonitorCreate { + _c.mutation.SetExtraModels(v) + return _c +} + +// SetGroupName sets the "group_name" field. +func (_c *ChannelMonitorCreate) SetGroupName(v string) *ChannelMonitorCreate { + _c.mutation.SetGroupName(v) + return _c +} + +// SetNillableGroupName sets the "group_name" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableGroupName(v *string) *ChannelMonitorCreate { + if v != nil { + _c.SetGroupName(*v) + } + return _c +} + +// SetEnabled sets the "enabled" field. +func (_c *ChannelMonitorCreate) SetEnabled(v bool) *ChannelMonitorCreate { + _c.mutation.SetEnabled(v) + return _c +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableEnabled(v *bool) *ChannelMonitorCreate { + if v != nil { + _c.SetEnabled(*v) + } + return _c +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (_c *ChannelMonitorCreate) SetIntervalSeconds(v int) *ChannelMonitorCreate { + _c.mutation.SetIntervalSeconds(v) + return _c +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (_c *ChannelMonitorCreate) SetLastCheckedAt(v time.Time) *ChannelMonitorCreate { + _c.mutation.SetLastCheckedAt(v) + return _c +} + +// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil. 
+func (_c *ChannelMonitorCreate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorCreate { + if v != nil { + _c.SetLastCheckedAt(*v) + } + return _c +} + +// SetCreatedBy sets the "created_by" field. +func (_c *ChannelMonitorCreate) SetCreatedBy(v int64) *ChannelMonitorCreate { + _c.mutation.SetCreatedBy(v) + return _c +} + +// SetTemplateID sets the "template_id" field. +func (_c *ChannelMonitorCreate) SetTemplateID(v int64) *ChannelMonitorCreate { + _c.mutation.SetTemplateID(v) + return _c +} + +// SetNillableTemplateID sets the "template_id" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableTemplateID(v *int64) *ChannelMonitorCreate { + if v != nil { + _c.SetTemplateID(*v) + } + return _c +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_c *ChannelMonitorCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorCreate { + _c.mutation.SetExtraHeaders(v) + return _c +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (_c *ChannelMonitorCreate) SetBodyOverrideMode(v string) *ChannelMonitorCreate { + _c.mutation.SetBodyOverrideMode(v) + return _c +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorCreate { + if v != nil { + _c.SetBodyOverrideMode(*v) + } + return _c +} + +// SetBodyOverride sets the "body_override" field. +func (_c *ChannelMonitorCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorCreate { + _c.mutation.SetBodyOverride(v) + return _c +} + +// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs. +func (_c *ChannelMonitorCreate) AddHistoryIDs(ids ...int64) *ChannelMonitorCreate { + _c.mutation.AddHistoryIDs(ids...) + return _c +} + +// AddHistory adds the "history" edges to the ChannelMonitorHistory entity. 
+func (_c *ChannelMonitorCreate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddHistoryIDs(ids...) +} + +// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs. +func (_c *ChannelMonitorCreate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorCreate { + _c.mutation.AddDailyRollupIDs(ids...) + return _c +} + +// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity. +func (_c *ChannelMonitorCreate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddDailyRollupIDs(ids...) +} + +// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID. +func (_c *ChannelMonitorCreate) SetRequestTemplateID(id int64) *ChannelMonitorCreate { + _c.mutation.SetRequestTemplateID(id) + return _c +} + +// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil. +func (_c *ChannelMonitorCreate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorCreate { + if id != nil { + _c = _c.SetRequestTemplateID(*id) + } + return _c +} + +// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (_c *ChannelMonitorCreate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorCreate { + return _c.SetRequestTemplateID(v.ID) +} + +// Mutation returns the ChannelMonitorMutation object of the builder. +func (_c *ChannelMonitorCreate) Mutation() *ChannelMonitorMutation { + return _c.mutation +} + +// Save creates the ChannelMonitor in the database. 
+func (_c *ChannelMonitorCreate) Save(ctx context.Context) (*ChannelMonitor, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ChannelMonitorCreate) SaveX(ctx context.Context) *ChannelMonitor { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ChannelMonitorCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ChannelMonitorCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := channelmonitor.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := channelmonitor.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.ExtraModels(); !ok { + v := channelmonitor.DefaultExtraModels + _c.mutation.SetExtraModels(v) + } + if _, ok := _c.mutation.GroupName(); !ok { + v := channelmonitor.DefaultGroupName + _c.mutation.SetGroupName(v) + } + if _, ok := _c.mutation.Enabled(); !ok { + v := channelmonitor.DefaultEnabled + _c.mutation.SetEnabled(v) + } + if _, ok := _c.mutation.ExtraHeaders(); !ok { + v := channelmonitor.DefaultExtraHeaders + _c.mutation.SetExtraHeaders(v) + } + if _, ok := _c.mutation.BodyOverrideMode(); !ok { + v := channelmonitor.DefaultBodyOverrideMode + _c.mutation.SetBodyOverrideMode(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *ChannelMonitorCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitor.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitor.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitor.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := channelmonitor.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)} + } + } + if _, ok := _c.mutation.Provider(); !ok { + return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitor.provider"`)} + } + if v, ok := _c.mutation.Provider(); ok { + if err := channelmonitor.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)} + } + } + if _, ok := _c.mutation.Endpoint(); !ok { + return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)} + } + if v, ok := _c.mutation.Endpoint(); ok { + if err := channelmonitor.EndpointValidator(v); err != nil { + return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)} + } + } + if _, ok := _c.mutation.APIKeyEncrypted(); !ok { + return &ValidationError{Name: "api_key_encrypted", err: errors.New(`ent: missing required field "ChannelMonitor.api_key_encrypted"`)} + } + if v, ok := _c.mutation.APIKeyEncrypted(); ok { + if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil { + return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for 
field "ChannelMonitor.api_key_encrypted": %w`, err)} + } + } + if _, ok := _c.mutation.PrimaryModel(); !ok { + return &ValidationError{Name: "primary_model", err: errors.New(`ent: missing required field "ChannelMonitor.primary_model"`)} + } + if v, ok := _c.mutation.PrimaryModel(); ok { + if err := channelmonitor.PrimaryModelValidator(v); err != nil { + return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)} + } + } + if _, ok := _c.mutation.ExtraModels(); !ok { + return &ValidationError{Name: "extra_models", err: errors.New(`ent: missing required field "ChannelMonitor.extra_models"`)} + } + if v, ok := _c.mutation.GroupName(); ok { + if err := channelmonitor.GroupNameValidator(v); err != nil { + return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)} + } + } + if _, ok := _c.mutation.Enabled(); !ok { + return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ChannelMonitor.enabled"`)} + } + if _, ok := _c.mutation.IntervalSeconds(); !ok { + return &ValidationError{Name: "interval_seconds", err: errors.New(`ent: missing required field "ChannelMonitor.interval_seconds"`)} + } + if v, ok := _c.mutation.IntervalSeconds(); ok { + if err := channelmonitor.IntervalSecondsValidator(v); err != nil { + return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)} + } + } + if _, ok := _c.mutation.CreatedBy(); !ok { + return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "ChannelMonitor.created_by"`)} + } + if _, ok := _c.mutation.ExtraHeaders(); !ok { + return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitor.extra_headers"`)} + } + if _, ok := _c.mutation.BodyOverrideMode(); !ok { + return &ValidationError{Name: 
"body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitor.body_override_mode"`)} + } + if v, ok := _c.mutation.BodyOverrideMode(); ok { + if err := channelmonitor.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_c *ChannelMonitorCreate) sqlSave(ctx context.Context) (*ChannelMonitor, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateSpec) { + var ( + _node = &ChannelMonitor{config: _c.config} + _spec = sqlgraph.NewCreateSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(channelmonitor.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(channelmonitor.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Provider(); ok { + _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value) + _node.Provider = value + } + if value, ok := _c.mutation.Endpoint(); ok { + _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value) + _node.Endpoint = value + } + if value, ok := _c.mutation.APIKeyEncrypted(); ok { + 
_spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value) + _node.APIKeyEncrypted = value + } + if value, ok := _c.mutation.PrimaryModel(); ok { + _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value) + _node.PrimaryModel = value + } + if value, ok := _c.mutation.ExtraModels(); ok { + _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value) + _node.ExtraModels = value + } + if value, ok := _c.mutation.GroupName(); ok { + _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value) + _node.GroupName = value + } + if value, ok := _c.mutation.Enabled(); ok { + _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value) + _node.Enabled = value + } + if value, ok := _c.mutation.IntervalSeconds(); ok { + _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value) + _node.IntervalSeconds = value + } + if value, ok := _c.mutation.LastCheckedAt(); ok { + _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value) + _node.LastCheckedAt = &value + } + if value, ok := _c.mutation.CreatedBy(); ok { + _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value) + _node.CreatedBy = value + } + if value, ok := _c.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value) + _node.ExtraHeaders = value + } + if value, ok := _c.mutation.BodyOverrideMode(); ok { + _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value) + _node.BodyOverrideMode = value + } + if value, ok := _c.mutation.BodyOverride(); ok { + _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value) + _node.BodyOverride = value + } + if nodes := _c.mutation.HistoryIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: 
sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.DailyRollupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.RequestTemplateIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: channelmonitor.RequestTemplateTable, + Columns: []string{channelmonitor.RequestTemplateColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TemplateID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitor.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (_c *ChannelMonitorCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertOne { + _c.conflict = opts + return &ChannelMonitorUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorCreate) OnConflictColumns(columns ...string) *ChannelMonitorUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorUpsertOne{ + create: _c, + } +} + +type ( + // ChannelMonitorUpsertOne is the builder for "upsert"-ing + // one ChannelMonitor node. + ChannelMonitorUpsertOne struct { + create *ChannelMonitorCreate + } + + // ChannelMonitorUpsert is the "OnConflict" setter. + ChannelMonitorUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateUpdatedAt() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *ChannelMonitorUpsert) SetName(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateName() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldName) + return u +} + +// SetProvider sets the "provider" field. +func (u *ChannelMonitorUpsert) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldProvider, v) + return u +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. 
+func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldProvider) + return u +} + +// SetEndpoint sets the "endpoint" field. +func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldEndpoint, v) + return u +} + +// UpdateEndpoint sets the "endpoint" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateEndpoint() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldEndpoint) + return u +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (u *ChannelMonitorUpsert) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldAPIKeyEncrypted, v) + return u +} + +// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateAPIKeyEncrypted() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldAPIKeyEncrypted) + return u +} + +// SetPrimaryModel sets the "primary_model" field. +func (u *ChannelMonitorUpsert) SetPrimaryModel(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldPrimaryModel, v) + return u +} + +// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdatePrimaryModel() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldPrimaryModel) + return u +} + +// SetExtraModels sets the "extra_models" field. +func (u *ChannelMonitorUpsert) SetExtraModels(v []string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldExtraModels, v) + return u +} + +// UpdateExtraModels sets the "extra_models" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateExtraModels() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldExtraModels) + return u +} + +// SetGroupName sets the "group_name" field. 
+func (u *ChannelMonitorUpsert) SetGroupName(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldGroupName, v) + return u +} + +// UpdateGroupName sets the "group_name" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateGroupName() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldGroupName) + return u +} + +// ClearGroupName clears the value of the "group_name" field. +func (u *ChannelMonitorUpsert) ClearGroupName() *ChannelMonitorUpsert { + u.SetNull(channelmonitor.FieldGroupName) + return u +} + +// SetEnabled sets the "enabled" field. +func (u *ChannelMonitorUpsert) SetEnabled(v bool) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldEnabled, v) + return u +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateEnabled() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldEnabled) + return u +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (u *ChannelMonitorUpsert) SetIntervalSeconds(v int) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldIntervalSeconds, v) + return u +} + +// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateIntervalSeconds() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldIntervalSeconds) + return u +} + +// AddIntervalSeconds adds v to the "interval_seconds" field. +func (u *ChannelMonitorUpsert) AddIntervalSeconds(v int) *ChannelMonitorUpsert { + u.Add(channelmonitor.FieldIntervalSeconds, v) + return u +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (u *ChannelMonitorUpsert) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldLastCheckedAt, v) + return u +} + +// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create. 
+func (u *ChannelMonitorUpsert) UpdateLastCheckedAt() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldLastCheckedAt) + return u +} + +// ClearLastCheckedAt clears the value of the "last_checked_at" field. +func (u *ChannelMonitorUpsert) ClearLastCheckedAt() *ChannelMonitorUpsert { + u.SetNull(channelmonitor.FieldLastCheckedAt) + return u +} + +// SetCreatedBy sets the "created_by" field. +func (u *ChannelMonitorUpsert) SetCreatedBy(v int64) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldCreatedBy, v) + return u +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateCreatedBy() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldCreatedBy) + return u +} + +// AddCreatedBy adds v to the "created_by" field. +func (u *ChannelMonitorUpsert) AddCreatedBy(v int64) *ChannelMonitorUpsert { + u.Add(channelmonitor.FieldCreatedBy, v) + return u +} + +// SetTemplateID sets the "template_id" field. +func (u *ChannelMonitorUpsert) SetTemplateID(v int64) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldTemplateID, v) + return u +} + +// UpdateTemplateID sets the "template_id" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateTemplateID() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldTemplateID) + return u +} + +// ClearTemplateID clears the value of the "template_id" field. +func (u *ChannelMonitorUpsert) ClearTemplateID() *ChannelMonitorUpsert { + u.SetNull(channelmonitor.FieldTemplateID) + return u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldExtraHeaders, v) + return u +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. 
+func (u *ChannelMonitorUpsert) UpdateExtraHeaders() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldExtraHeaders) + return u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (u *ChannelMonitorUpsert) SetBodyOverrideMode(v string) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldBodyOverrideMode, v) + return u +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateBodyOverrideMode() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldBodyOverrideMode) + return u +} + +// SetBodyOverride sets the "body_override" field. +func (u *ChannelMonitorUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsert { + u.Set(channelmonitor.FieldBodyOverride, v) + return u +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. +func (u *ChannelMonitorUpsert) UpdateBodyOverride() *ChannelMonitorUpsert { + u.SetExcluded(channelmonitor.FieldBodyOverride) + return u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorUpsert) ClearBodyOverride() *ChannelMonitorUpsert { + u.SetNull(channelmonitor.FieldBodyOverride) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ChannelMonitorUpsertOne) UpdateNewValues() *ChannelMonitorUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(channelmonitor.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. 
+// Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorUpsertOne) Ignore() *ChannelMonitorUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorUpsertOne) DoNothing() *ChannelMonitorUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreate.OnConflict +// documentation for more info. +func (u *ChannelMonitorUpsertOne) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateUpdatedAt() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ChannelMonitorUpsertOne) SetName(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateName() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateName() + }) +} + +// SetProvider sets the "provider" field. 
+func (u *ChannelMonitorUpsertOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetProvider(v) + }) +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateProvider() + }) +} + +// SetEndpoint sets the "endpoint" field. +func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetEndpoint(v) + }) +} + +// UpdateEndpoint sets the "endpoint" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateEndpoint() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateEndpoint() + }) +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (u *ChannelMonitorUpsertOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetAPIKeyEncrypted(v) + }) +} + +// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateAPIKeyEncrypted() + }) +} + +// SetPrimaryModel sets the "primary_model" field. +func (u *ChannelMonitorUpsertOne) SetPrimaryModel(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetPrimaryModel(v) + }) +} + +// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdatePrimaryModel() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdatePrimaryModel() + }) +} + +// SetExtraModels sets the "extra_models" field. 
+func (u *ChannelMonitorUpsertOne) SetExtraModels(v []string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetExtraModels(v) + }) +} + +// UpdateExtraModels sets the "extra_models" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateExtraModels() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateExtraModels() + }) +} + +// SetGroupName sets the "group_name" field. +func (u *ChannelMonitorUpsertOne) SetGroupName(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetGroupName(v) + }) +} + +// UpdateGroupName sets the "group_name" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateGroupName() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateGroupName() + }) +} + +// ClearGroupName clears the value of the "group_name" field. +func (u *ChannelMonitorUpsertOne) ClearGroupName() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearGroupName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ChannelMonitorUpsertOne) SetEnabled(v bool) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateEnabled() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateEnabled() + }) +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (u *ChannelMonitorUpsertOne) SetIntervalSeconds(v int) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetIntervalSeconds(v) + }) +} + +// AddIntervalSeconds adds v to the "interval_seconds" field. 
+func (u *ChannelMonitorUpsertOne) AddIntervalSeconds(v int) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.AddIntervalSeconds(v) + }) +} + +// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateIntervalSeconds() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateIntervalSeconds() + }) +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (u *ChannelMonitorUpsertOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetLastCheckedAt(v) + }) +} + +// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateLastCheckedAt() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateLastCheckedAt() + }) +} + +// ClearLastCheckedAt clears the value of the "last_checked_at" field. +func (u *ChannelMonitorUpsertOne) ClearLastCheckedAt() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearLastCheckedAt() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *ChannelMonitorUpsertOne) SetCreatedBy(v int64) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetCreatedBy(v) + }) +} + +// AddCreatedBy adds v to the "created_by" field. +func (u *ChannelMonitorUpsertOne) AddCreatedBy(v int64) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.AddCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateCreatedBy() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateCreatedBy() + }) +} + +// SetTemplateID sets the "template_id" field. 
+func (u *ChannelMonitorUpsertOne) SetTemplateID(v int64) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetTemplateID(v) + }) +} + +// UpdateTemplateID sets the "template_id" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateTemplateID() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateTemplateID() + }) +} + +// ClearTemplateID clears the value of the "template_id" field. +func (u *ChannelMonitorUpsertOne) ClearTemplateID() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearTemplateID() + }) +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetExtraHeaders(v) + }) +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateExtraHeaders() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateExtraHeaders() + }) +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (u *ChannelMonitorUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetBodyOverrideMode(v) + }) +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateBodyOverrideMode() + }) +} + +// SetBodyOverride sets the "body_override" field. 
+func (u *ChannelMonitorUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetBodyOverride(v) + }) +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. +func (u *ChannelMonitorUpsertOne) UpdateBodyOverride() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateBodyOverride() + }) +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorUpsertOne) ClearBodyOverride() *ChannelMonitorUpsertOne { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearBodyOverride() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ChannelMonitorUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ChannelMonitorUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ChannelMonitorCreateBulk is the builder for creating many ChannelMonitor entities in bulk. +type ChannelMonitorCreateBulk struct { + config + err error + builders []*ChannelMonitorCreate + conflict []sql.ConflictOption +} + +// Save creates the ChannelMonitor entities in the database. 
+func (_c *ChannelMonitorCreateBulk) Save(ctx context.Context) ([]*ChannelMonitor, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ChannelMonitor, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ChannelMonitorMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *ChannelMonitorCreateBulk) SaveX(ctx context.Context) []*ChannelMonitor { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *ChannelMonitorCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitor.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ChannelMonitorCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertBulk { + _c.conflict = opts + return &ChannelMonitorUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorUpsertBulk{ + create: _c, + } +} + +// ChannelMonitorUpsertBulk is the builder for "upsert"-ing +// a bulk of ChannelMonitor nodes. +type ChannelMonitorUpsertBulk struct { + create *ChannelMonitorCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *ChannelMonitorUpsertBulk) UpdateNewValues() *ChannelMonitorUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(channelmonitor.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ChannelMonitor.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorUpsertBulk) Ignore() *ChannelMonitorUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorUpsertBulk) DoNothing() *ChannelMonitorUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreateBulk.OnConflict +// documentation for more info. +func (u *ChannelMonitorUpsertBulk) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateUpdatedAt() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. 
+func (u *ChannelMonitorUpsertBulk) SetName(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateName() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateName() + }) +} + +// SetProvider sets the "provider" field. +func (u *ChannelMonitorUpsertBulk) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetProvider(v) + }) +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateProvider() + }) +} + +// SetEndpoint sets the "endpoint" field. +func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetEndpoint(v) + }) +} + +// UpdateEndpoint sets the "endpoint" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateEndpoint() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateEndpoint() + }) +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (u *ChannelMonitorUpsertBulk) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetAPIKeyEncrypted(v) + }) +} + +// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateAPIKeyEncrypted() + }) +} + +// SetPrimaryModel sets the "primary_model" field. 
+func (u *ChannelMonitorUpsertBulk) SetPrimaryModel(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetPrimaryModel(v) + }) +} + +// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdatePrimaryModel() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdatePrimaryModel() + }) +} + +// SetExtraModels sets the "extra_models" field. +func (u *ChannelMonitorUpsertBulk) SetExtraModels(v []string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetExtraModels(v) + }) +} + +// UpdateExtraModels sets the "extra_models" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateExtraModels() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateExtraModels() + }) +} + +// SetGroupName sets the "group_name" field. +func (u *ChannelMonitorUpsertBulk) SetGroupName(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetGroupName(v) + }) +} + +// UpdateGroupName sets the "group_name" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateGroupName() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateGroupName() + }) +} + +// ClearGroupName clears the value of the "group_name" field. +func (u *ChannelMonitorUpsertBulk) ClearGroupName() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearGroupName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ChannelMonitorUpsertBulk) SetEnabled(v bool) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. 
+func (u *ChannelMonitorUpsertBulk) UpdateEnabled() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateEnabled() + }) +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (u *ChannelMonitorUpsertBulk) SetIntervalSeconds(v int) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetIntervalSeconds(v) + }) +} + +// AddIntervalSeconds adds v to the "interval_seconds" field. +func (u *ChannelMonitorUpsertBulk) AddIntervalSeconds(v int) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.AddIntervalSeconds(v) + }) +} + +// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateIntervalSeconds() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateIntervalSeconds() + }) +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (u *ChannelMonitorUpsertBulk) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetLastCheckedAt(v) + }) +} + +// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateLastCheckedAt() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateLastCheckedAt() + }) +} + +// ClearLastCheckedAt clears the value of the "last_checked_at" field. +func (u *ChannelMonitorUpsertBulk) ClearLastCheckedAt() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearLastCheckedAt() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *ChannelMonitorUpsertBulk) SetCreatedBy(v int64) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetCreatedBy(v) + }) +} + +// AddCreatedBy adds v to the "created_by" field. 
+func (u *ChannelMonitorUpsertBulk) AddCreatedBy(v int64) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.AddCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateCreatedBy() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateCreatedBy() + }) +} + +// SetTemplateID sets the "template_id" field. +func (u *ChannelMonitorUpsertBulk) SetTemplateID(v int64) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetTemplateID(v) + }) +} + +// UpdateTemplateID sets the "template_id" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateTemplateID() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateTemplateID() + }) +} + +// ClearTemplateID clears the value of the "template_id" field. +func (u *ChannelMonitorUpsertBulk) ClearTemplateID() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearTemplateID() + }) +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetExtraHeaders(v) + }) +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateExtraHeaders() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateExtraHeaders() + }) +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (u *ChannelMonitorUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetBodyOverrideMode(v) + }) +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. 
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateBodyOverrideMode() + }) +} + +// SetBodyOverride sets the "body_override" field. +func (u *ChannelMonitorUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.SetBodyOverride(v) + }) +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. +func (u *ChannelMonitorUpsertBulk) UpdateBodyOverride() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.UpdateBodyOverride() + }) +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorUpsertBulk) ClearBodyOverride() *ChannelMonitorUpsertBulk { + return u.Update(func(s *ChannelMonitorUpsert) { + s.ClearBodyOverride() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitor_delete.go b/backend/ent/channelmonitor_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..500dbb4872ceca5808b474c708818064032fd2c6 --- /dev/null +++ b/backend/ent/channelmonitor_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorDelete is the builder for deleting a ChannelMonitor entity. +type ChannelMonitorDelete struct { + config + hooks []Hook + mutation *ChannelMonitorMutation +} + +// Where appends a list predicates to the ChannelMonitorDelete builder. +func (_d *ChannelMonitorDelete) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ChannelMonitorDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ChannelMonitorDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ChannelMonitorDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ChannelMonitorDeleteOne is the builder for deleting a single ChannelMonitor entity. +type ChannelMonitorDeleteOne struct { + _d *ChannelMonitorDelete +} + +// Where appends a list predicates to the ChannelMonitorDelete builder. +func (_d *ChannelMonitorDeleteOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDeleteOne { + _d._d.mutation.Where(ps...) 
+ return _d +} + +// Exec executes the deletion query. +func (_d *ChannelMonitorDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{channelmonitor.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ChannelMonitorDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitor_query.go b/backend/ent/channelmonitor_query.go new file mode 100644 index 0000000000000000000000000000000000000000..b6722e7801bdf88407342d3099f870892741d11a --- /dev/null +++ b/backend/ent/channelmonitor_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorQuery is the builder for querying ChannelMonitor entities. +type ChannelMonitorQuery struct { + config + ctx *QueryContext + order []channelmonitor.OrderOption + inters []Interceptor + predicates []predicate.ChannelMonitor + withHistory *ChannelMonitorHistoryQuery + withDailyRollups *ChannelMonitorDailyRollupQuery + withRequestTemplate *ChannelMonitorRequestTemplateQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ChannelMonitorQuery builder. 
+func (_q *ChannelMonitorQuery) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ChannelMonitorQuery) Limit(limit int) *ChannelMonitorQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *ChannelMonitorQuery) Offset(offset int) *ChannelMonitorQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ChannelMonitorQuery) Unique(unique bool) *ChannelMonitorQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ChannelMonitorQuery) Order(o ...channelmonitor.OrderOption) *ChannelMonitorQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryHistory chains the current query on the "history" edge. +func (_q *ChannelMonitorQuery) QueryHistory() *ChannelMonitorHistoryQuery { + query := (&ChannelMonitorHistoryClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector), + sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryDailyRollups chains the current query on the "daily_rollups" edge. 
+func (_q *ChannelMonitorQuery) QueryDailyRollups() *ChannelMonitorDailyRollupQuery { + query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector), + sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryRequestTemplate chains the current query on the "request_template" edge. +func (_q *ChannelMonitorQuery) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery { + query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector), + sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first ChannelMonitor entity from the query. +// Returns a *NotFoundError when no ChannelMonitor was found. 
+func (_q *ChannelMonitorQuery) First(ctx context.Context) (*ChannelMonitor, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{channelmonitor.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ChannelMonitorQuery) FirstX(ctx context.Context) *ChannelMonitor { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ChannelMonitor ID from the query. +// Returns a *NotFoundError when no ChannelMonitor ID was found. +func (_q *ChannelMonitorQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{channelmonitor.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ChannelMonitorQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ChannelMonitor entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ChannelMonitor entity is found. +// Returns a *NotFoundError when no ChannelMonitor entities are found. +func (_q *ChannelMonitorQuery) Only(ctx context.Context) (*ChannelMonitor, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{channelmonitor.Label} + default: + return nil, &NotSingularError{channelmonitor.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *ChannelMonitorQuery) OnlyX(ctx context.Context) *ChannelMonitor { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ChannelMonitor ID in the query. +// Returns a *NotSingularError when more than one ChannelMonitor ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ChannelMonitorQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{channelmonitor.Label} + default: + err = &NotSingularError{channelmonitor.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ChannelMonitorQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ChannelMonitors. +func (_q *ChannelMonitorQuery) All(ctx context.Context) ([]*ChannelMonitor, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ChannelMonitor, *ChannelMonitorQuery]() + return withInterceptors[[]*ChannelMonitor](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ChannelMonitorQuery) AllX(ctx context.Context) []*ChannelMonitor { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ChannelMonitor IDs. 
+func (_q *ChannelMonitorQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(channelmonitor.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ChannelMonitorQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ChannelMonitorQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ChannelMonitorQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *ChannelMonitorQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ChannelMonitorQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ChannelMonitorQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *ChannelMonitorQuery) Clone() *ChannelMonitorQuery { + if _q == nil { + return nil + } + return &ChannelMonitorQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]channelmonitor.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ChannelMonitor{}, _q.predicates...), + withHistory: _q.withHistory.Clone(), + withDailyRollups: _q.withDailyRollups.Clone(), + withRequestTemplate: _q.withRequestTemplate.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithHistory tells the query-builder to eager-load the nodes that are connected to +// the "history" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChannelMonitorQuery) WithHistory(opts ...func(*ChannelMonitorHistoryQuery)) *ChannelMonitorQuery { + query := (&ChannelMonitorHistoryClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withHistory = query + return _q +} + +// WithDailyRollups tells the query-builder to eager-load the nodes that are connected to +// the "daily_rollups" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChannelMonitorQuery) WithDailyRollups(opts ...func(*ChannelMonitorDailyRollupQuery)) *ChannelMonitorQuery { + query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withDailyRollups = query + return _q +} + +// WithRequestTemplate tells the query-builder to eager-load the nodes that are connected to +// the "request_template" edge. The optional arguments are used to configure the query builder of the edge. 
+func (_q *ChannelMonitorQuery) WithRequestTemplate(opts ...func(*ChannelMonitorRequestTemplateQuery)) *ChannelMonitorQuery { + query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withRequestTemplate = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ChannelMonitor.Query(). +// GroupBy(channelmonitor.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ChannelMonitorQuery) GroupBy(field string, fields ...string) *ChannelMonitorGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ChannelMonitorGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = channelmonitor.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.ChannelMonitor.Query(). +// Select(channelmonitor.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *ChannelMonitorQuery) Select(fields ...string) *ChannelMonitorSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ChannelMonitorSelect{ChannelMonitorQuery: _q} + sbuild.label = channelmonitor.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ChannelMonitorSelect configured with the given aggregations. +func (_q *ChannelMonitorQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *ChannelMonitorQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !channelmonitor.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitor, error) { + var ( + nodes = []*ChannelMonitor{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withHistory != nil, + _q.withDailyRollups != nil, + _q.withRequestTemplate != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ChannelMonitor).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ChannelMonitor{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withHistory; query != nil { + if err := _q.loadHistory(ctx, query, nodes, + func(n *ChannelMonitor) { n.Edges.History = []*ChannelMonitorHistory{} }, + func(n *ChannelMonitor, e *ChannelMonitorHistory) { n.Edges.History = append(n.Edges.History, e) }); err != nil { + return nil, err + } + } + if query := _q.withDailyRollups; query != nil { + if err := _q.loadDailyRollups(ctx, query, nodes, + func(n *ChannelMonitor) { n.Edges.DailyRollups = 
[]*ChannelMonitorDailyRollup{} }, + func(n *ChannelMonitor, e *ChannelMonitorDailyRollup) { + n.Edges.DailyRollups = append(n.Edges.DailyRollups, e) + }); err != nil { + return nil, err + } + } + if query := _q.withRequestTemplate; query != nil { + if err := _q.loadRequestTemplate(ctx, query, nodes, nil, + func(n *ChannelMonitor, e *ChannelMonitorRequestTemplate) { n.Edges.RequestTemplate = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *ChannelMonitorQuery) loadHistory(ctx context.Context, query *ChannelMonitorHistoryQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorHistory)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*ChannelMonitor) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(channelmonitorhistory.FieldMonitorID) + } + query.Where(predicate.ChannelMonitorHistory(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(channelmonitor.HistoryColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.MonitorID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *ChannelMonitorQuery) loadDailyRollups(ctx context.Context, query *ChannelMonitorDailyRollupQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorDailyRollup)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*ChannelMonitor) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + 
query.ctx.AppendFieldOnce(channelmonitordailyrollup.FieldMonitorID) + } + query.Where(predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(channelmonitor.DailyRollupsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.MonitorID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *ChannelMonitorQuery) loadRequestTemplate(ctx context.Context, query *ChannelMonitorRequestTemplateQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorRequestTemplate)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*ChannelMonitor) + for i := range nodes { + if nodes[i].TemplateID == nil { + continue + } + fk := *nodes[i].TemplateID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(channelmonitorrequesttemplate.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "template_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *ChannelMonitorQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ChannelMonitorQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(channelmonitor.Table, channelmonitor.Columns, 
sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID) + for i := range fields { + if fields[i] != channelmonitor.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withRequestTemplate != nil { + _spec.Node.AddColumnOnce(channelmonitor.FieldTemplateID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ChannelMonitorQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(channelmonitor.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = channelmonitor.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ChannelMonitorQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ChannelMonitorQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ChannelMonitorGroupBy is the group-by builder for ChannelMonitor entities. +type ChannelMonitorGroupBy struct { + selector + build *ChannelMonitorQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ChannelMonitorGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *ChannelMonitorGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ChannelMonitorGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ChannelMonitorSelect is the builder for selecting fields of ChannelMonitor entities. +type ChannelMonitorSelect struct { + *ChannelMonitorQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ChannelMonitorSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *ChannelMonitorSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorSelect](ctx, _s.ChannelMonitorQuery, _s, _s.inters, v) +} + +func (_s *ChannelMonitorSelect) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go new file mode 100644 index 0000000000000000000000000000000000000000..4bbcd564e447bfb81fbea35ff59330def8a2937e --- /dev/null +++ b/backend/ent/channelmonitor_update.go @@ -0,0 +1,1328 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorUpdate is the builder for updating ChannelMonitor entities. 
+type ChannelMonitorUpdate struct { + config + hooks []Hook + mutation *ChannelMonitorMutation +} + +// Where appends a list predicates to the ChannelMonitorUpdate builder. +func (_u *ChannelMonitorUpdate) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ChannelMonitorUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ChannelMonitorUpdate) SetName(v string) *ChannelMonitorUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableName(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetProvider sets the "provider" field. +func (_u *ChannelMonitorUpdate) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdate { + _u.mutation.SetProvider(v) + return _u +} + +// SetNillableProvider sets the "provider" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdate { + if v != nil { + _u.SetProvider(*v) + } + return _u +} + +// SetEndpoint sets the "endpoint" field. +func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate { + _u.mutation.SetEndpoint(v) + return _u +} + +// SetNillableEndpoint sets the "endpoint" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableEndpoint(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetEndpoint(*v) + } + return _u +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (_u *ChannelMonitorUpdate) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdate { + _u.mutation.SetAPIKeyEncrypted(v) + return _u +} + +// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil. 
+func (_u *ChannelMonitorUpdate) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetAPIKeyEncrypted(*v) + } + return _u +} + +// SetPrimaryModel sets the "primary_model" field. +func (_u *ChannelMonitorUpdate) SetPrimaryModel(v string) *ChannelMonitorUpdate { + _u.mutation.SetPrimaryModel(v) + return _u +} + +// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetPrimaryModel(*v) + } + return _u +} + +// SetExtraModels sets the "extra_models" field. +func (_u *ChannelMonitorUpdate) SetExtraModels(v []string) *ChannelMonitorUpdate { + _u.mutation.SetExtraModels(v) + return _u +} + +// AppendExtraModels appends value to the "extra_models" field. +func (_u *ChannelMonitorUpdate) AppendExtraModels(v []string) *ChannelMonitorUpdate { + _u.mutation.AppendExtraModels(v) + return _u +} + +// SetGroupName sets the "group_name" field. +func (_u *ChannelMonitorUpdate) SetGroupName(v string) *ChannelMonitorUpdate { + _u.mutation.SetGroupName(v) + return _u +} + +// SetNillableGroupName sets the "group_name" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableGroupName(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetGroupName(*v) + } + return _u +} + +// ClearGroupName clears the value of the "group_name" field. +func (_u *ChannelMonitorUpdate) ClearGroupName() *ChannelMonitorUpdate { + _u.mutation.ClearGroupName() + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ChannelMonitorUpdate) SetEnabled(v bool) *ChannelMonitorUpdate { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. 
+func (_u *ChannelMonitorUpdate) SetNillableEnabled(v *bool) *ChannelMonitorUpdate { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (_u *ChannelMonitorUpdate) SetIntervalSeconds(v int) *ChannelMonitorUpdate { + _u.mutation.ResetIntervalSeconds() + _u.mutation.SetIntervalSeconds(v) + return _u +} + +// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdate { + if v != nil { + _u.SetIntervalSeconds(*v) + } + return _u +} + +// AddIntervalSeconds adds value to the "interval_seconds" field. +func (_u *ChannelMonitorUpdate) AddIntervalSeconds(v int) *ChannelMonitorUpdate { + _u.mutation.AddIntervalSeconds(v) + return _u +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (_u *ChannelMonitorUpdate) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdate { + _u.mutation.SetLastCheckedAt(v) + return _u +} + +// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdate { + if v != nil { + _u.SetLastCheckedAt(*v) + } + return _u +} + +// ClearLastCheckedAt clears the value of the "last_checked_at" field. +func (_u *ChannelMonitorUpdate) ClearLastCheckedAt() *ChannelMonitorUpdate { + _u.mutation.ClearLastCheckedAt() + return _u +} + +// SetCreatedBy sets the "created_by" field. +func (_u *ChannelMonitorUpdate) SetCreatedBy(v int64) *ChannelMonitorUpdate { + _u.mutation.ResetCreatedBy() + _u.mutation.SetCreatedBy(v) + return _u +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdate { + if v != nil { + _u.SetCreatedBy(*v) + } + return _u +} + +// AddCreatedBy adds value to the "created_by" field. 
+func (_u *ChannelMonitorUpdate) AddCreatedBy(v int64) *ChannelMonitorUpdate { + _u.mutation.AddCreatedBy(v) + return _u +} + +// SetTemplateID sets the "template_id" field. +func (_u *ChannelMonitorUpdate) SetTemplateID(v int64) *ChannelMonitorUpdate { + _u.mutation.SetTemplateID(v) + return _u +} + +// SetNillableTemplateID sets the "template_id" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableTemplateID(v *int64) *ChannelMonitorUpdate { + if v != nil { + _u.SetTemplateID(*v) + } + return _u +} + +// ClearTemplateID clears the value of the "template_id" field. +func (_u *ChannelMonitorUpdate) ClearTemplateID() *ChannelMonitorUpdate { + _u.mutation.ClearTemplateID() + return _u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_u *ChannelMonitorUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdate { + _u.mutation.SetExtraHeaders(v) + return _u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (_u *ChannelMonitorUpdate) SetBodyOverrideMode(v string) *ChannelMonitorUpdate { + _u.mutation.SetBodyOverrideMode(v) + return _u +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdate { + if v != nil { + _u.SetBodyOverrideMode(*v) + } + return _u +} + +// SetBodyOverride sets the "body_override" field. +func (_u *ChannelMonitorUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdate { + _u.mutation.SetBodyOverride(v) + return _u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (_u *ChannelMonitorUpdate) ClearBodyOverride() *ChannelMonitorUpdate { + _u.mutation.ClearBodyOverride() + return _u +} + +// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs. +func (_u *ChannelMonitorUpdate) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdate { + _u.mutation.AddHistoryIDs(ids...) 
+ return _u +} + +// AddHistory adds the "history" edges to the ChannelMonitorHistory entity. +func (_u *ChannelMonitorUpdate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddHistoryIDs(ids...) +} + +// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs. +func (_u *ChannelMonitorUpdate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate { + _u.mutation.AddDailyRollupIDs(ids...) + return _u +} + +// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity. +func (_u *ChannelMonitorUpdate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddDailyRollupIDs(ids...) +} + +// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID. +func (_u *ChannelMonitorUpdate) SetRequestTemplateID(id int64) *ChannelMonitorUpdate { + _u.mutation.SetRequestTemplateID(id) + return _u +} + +// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil. +func (_u *ChannelMonitorUpdate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdate { + if id != nil { + _u = _u.SetRequestTemplateID(*id) + } + return _u +} + +// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (_u *ChannelMonitorUpdate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdate { + return _u.SetRequestTemplateID(v.ID) +} + +// Mutation returns the ChannelMonitorMutation object of the builder. +func (_u *ChannelMonitorUpdate) Mutation() *ChannelMonitorMutation { + return _u.mutation +} + +// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity. 
+func (_u *ChannelMonitorUpdate) ClearHistory() *ChannelMonitorUpdate { + _u.mutation.ClearHistory() + return _u +} + +// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs. +func (_u *ChannelMonitorUpdate) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdate { + _u.mutation.RemoveHistoryIDs(ids...) + return _u +} + +// RemoveHistory removes "history" edges to ChannelMonitorHistory entities. +func (_u *ChannelMonitorUpdate) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveHistoryIDs(ids...) +} + +// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity. +func (_u *ChannelMonitorUpdate) ClearDailyRollups() *ChannelMonitorUpdate { + _u.mutation.ClearDailyRollups() + return _u +} + +// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs. +func (_u *ChannelMonitorUpdate) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate { + _u.mutation.RemoveDailyRollupIDs(ids...) + return _u +} + +// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities. +func (_u *ChannelMonitorUpdate) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveDailyRollupIDs(ids...) +} + +// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (_u *ChannelMonitorUpdate) ClearRequestTemplate() *ChannelMonitorUpdate { + _u.mutation.ClearRequestTemplate() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ChannelMonitorUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_u *ChannelMonitorUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ChannelMonitorUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ChannelMonitorUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := channelmonitor.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := channelmonitor.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)} + } + } + if v, ok := _u.mutation.Provider(); ok { + if err := channelmonitor.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)} + } + } + if v, ok := _u.mutation.Endpoint(); ok { + if err := channelmonitor.EndpointValidator(v); err != nil { + return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)} + } + } + if v, ok := _u.mutation.APIKeyEncrypted(); ok { + if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil { + return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)} + } + } + if v, ok := _u.mutation.PrimaryModel(); ok { + if err := channelmonitor.PrimaryModelValidator(v); err != nil { + return &ValidationError{Name: "primary_model", err: 
fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)} + } + } + if v, ok := _u.mutation.GroupName(); ok { + if err := channelmonitor.GroupNameValidator(v); err != nil { + return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)} + } + } + if v, ok := _u.mutation.IntervalSeconds(); ok { + if err := channelmonitor.IntervalSecondsValidator(v); err != nil { + return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)} + } + } + if v, ok := _u.mutation.BodyOverrideMode(); ok { + if err := channelmonitor.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(channelmonitor.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Provider(); ok { + _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value) + } + if value, ok := _u.mutation.Endpoint(); ok { + _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value) + } + if value, ok := _u.mutation.APIKeyEncrypted(); ok { + _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value) + } + if value, ok := 
_u.mutation.PrimaryModel(); ok { + _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value) + } + if value, ok := _u.mutation.ExtraModels(); ok { + _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedExtraModels(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, channelmonitor.FieldExtraModels, value) + }) + } + if value, ok := _u.mutation.GroupName(); ok { + _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value) + } + if _u.mutation.GroupNameCleared() { + _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.IntervalSeconds(); ok { + _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedIntervalSeconds(); ok { + _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value) + } + if value, ok := _u.mutation.LastCheckedAt(); ok { + _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value) + } + if _u.mutation.LastCheckedAtCleared() { + _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime) + } + if value, ok := _u.mutation.CreatedBy(); ok { + _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCreatedBy(); ok { + _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value) + } + if value, ok := _u.mutation.BodyOverrideMode(); ok { + _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value) + } + if value, ok := _u.mutation.BodyOverride(); ok { + _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value) + } + if _u.mutation.BodyOverrideCleared() { + 
_spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON) + } + if _u.mutation.HistoryCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.DailyRollupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() { + 
edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.RequestTemplateCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: channelmonitor.RequestTemplateTable, + Columns: []string{channelmonitor.RequestTemplateColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: channelmonitor.RequestTemplateTable, + Columns: []string{channelmonitor.RequestTemplateColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := 
err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitor.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ChannelMonitorUpdateOne is the builder for updating a single ChannelMonitor entity. +type ChannelMonitorUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ChannelMonitorMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ChannelMonitorUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ChannelMonitorUpdateOne) SetName(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableName(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetProvider sets the "provider" field. +func (_u *ChannelMonitorUpdateOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdateOne { + _u.mutation.SetProvider(v) + return _u +} + +// SetNillableProvider sets the "provider" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetProvider(*v) + } + return _u +} + +// SetEndpoint sets the "endpoint" field. +func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetEndpoint(v) + return _u +} + +// SetNillableEndpoint sets the "endpoint" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableEndpoint(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetEndpoint(*v) + } + return _u +} + +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. 
+func (_u *ChannelMonitorUpdateOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetAPIKeyEncrypted(v) + return _u +} + +// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetAPIKeyEncrypted(*v) + } + return _u +} + +// SetPrimaryModel sets the "primary_model" field. +func (_u *ChannelMonitorUpdateOne) SetPrimaryModel(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetPrimaryModel(v) + return _u +} + +// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetPrimaryModel(*v) + } + return _u +} + +// SetExtraModels sets the "extra_models" field. +func (_u *ChannelMonitorUpdateOne) SetExtraModels(v []string) *ChannelMonitorUpdateOne { + _u.mutation.SetExtraModels(v) + return _u +} + +// AppendExtraModels appends value to the "extra_models" field. +func (_u *ChannelMonitorUpdateOne) AppendExtraModels(v []string) *ChannelMonitorUpdateOne { + _u.mutation.AppendExtraModels(v) + return _u +} + +// SetGroupName sets the "group_name" field. +func (_u *ChannelMonitorUpdateOne) SetGroupName(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetGroupName(v) + return _u +} + +// SetNillableGroupName sets the "group_name" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableGroupName(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetGroupName(*v) + } + return _u +} + +// ClearGroupName clears the value of the "group_name" field. +func (_u *ChannelMonitorUpdateOne) ClearGroupName() *ChannelMonitorUpdateOne { + _u.mutation.ClearGroupName() + return _u +} + +// SetEnabled sets the "enabled" field. 
+func (_u *ChannelMonitorUpdateOne) SetEnabled(v bool) *ChannelMonitorUpdateOne { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableEnabled(v *bool) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetIntervalSeconds sets the "interval_seconds" field. +func (_u *ChannelMonitorUpdateOne) SetIntervalSeconds(v int) *ChannelMonitorUpdateOne { + _u.mutation.ResetIntervalSeconds() + _u.mutation.SetIntervalSeconds(v) + return _u +} + +// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetIntervalSeconds(*v) + } + return _u +} + +// AddIntervalSeconds adds value to the "interval_seconds" field. +func (_u *ChannelMonitorUpdateOne) AddIntervalSeconds(v int) *ChannelMonitorUpdateOne { + _u.mutation.AddIntervalSeconds(v) + return _u +} + +// SetLastCheckedAt sets the "last_checked_at" field. +func (_u *ChannelMonitorUpdateOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdateOne { + _u.mutation.SetLastCheckedAt(v) + return _u +} + +// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetLastCheckedAt(*v) + } + return _u +} + +// ClearLastCheckedAt clears the value of the "last_checked_at" field. +func (_u *ChannelMonitorUpdateOne) ClearLastCheckedAt() *ChannelMonitorUpdateOne { + _u.mutation.ClearLastCheckedAt() + return _u +} + +// SetCreatedBy sets the "created_by" field. 
+func (_u *ChannelMonitorUpdateOne) SetCreatedBy(v int64) *ChannelMonitorUpdateOne { + _u.mutation.ResetCreatedBy() + _u.mutation.SetCreatedBy(v) + return _u +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetCreatedBy(*v) + } + return _u +} + +// AddCreatedBy adds value to the "created_by" field. +func (_u *ChannelMonitorUpdateOne) AddCreatedBy(v int64) *ChannelMonitorUpdateOne { + _u.mutation.AddCreatedBy(v) + return _u +} + +// SetTemplateID sets the "template_id" field. +func (_u *ChannelMonitorUpdateOne) SetTemplateID(v int64) *ChannelMonitorUpdateOne { + _u.mutation.SetTemplateID(v) + return _u +} + +// SetNillableTemplateID sets the "template_id" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableTemplateID(v *int64) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetTemplateID(*v) + } + return _u +} + +// ClearTemplateID clears the value of the "template_id" field. +func (_u *ChannelMonitorUpdateOne) ClearTemplateID() *ChannelMonitorUpdateOne { + _u.mutation.ClearTemplateID() + return _u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_u *ChannelMonitorUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdateOne { + _u.mutation.SetExtraHeaders(v) + return _u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (_u *ChannelMonitorUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorUpdateOne { + _u.mutation.SetBodyOverrideMode(v) + return _u +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_u *ChannelMonitorUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdateOne { + if v != nil { + _u.SetBodyOverrideMode(*v) + } + return _u +} + +// SetBodyOverride sets the "body_override" field. 
+func (_u *ChannelMonitorUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdateOne { + _u.mutation.SetBodyOverride(v) + return _u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (_u *ChannelMonitorUpdateOne) ClearBodyOverride() *ChannelMonitorUpdateOne { + _u.mutation.ClearBodyOverride() + return _u +} + +// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs. +func (_u *ChannelMonitorUpdateOne) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne { + _u.mutation.AddHistoryIDs(ids...) + return _u +} + +// AddHistory adds the "history" edges to the ChannelMonitorHistory entity. +func (_u *ChannelMonitorUpdateOne) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddHistoryIDs(ids...) +} + +// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs. +func (_u *ChannelMonitorUpdateOne) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne { + _u.mutation.AddDailyRollupIDs(ids...) + return _u +} + +// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity. +func (_u *ChannelMonitorUpdateOne) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddDailyRollupIDs(ids...) +} + +// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID. +func (_u *ChannelMonitorUpdateOne) SetRequestTemplateID(id int64) *ChannelMonitorUpdateOne { + _u.mutation.SetRequestTemplateID(id) + return _u +} + +// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil. 
+func (_u *ChannelMonitorUpdateOne) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdateOne { + if id != nil { + _u = _u.SetRequestTemplateID(*id) + } + return _u +} + +// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (_u *ChannelMonitorUpdateOne) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdateOne { + return _u.SetRequestTemplateID(v.ID) +} + +// Mutation returns the ChannelMonitorMutation object of the builder. +func (_u *ChannelMonitorUpdateOne) Mutation() *ChannelMonitorMutation { + return _u.mutation +} + +// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity. +func (_u *ChannelMonitorUpdateOne) ClearHistory() *ChannelMonitorUpdateOne { + _u.mutation.ClearHistory() + return _u +} + +// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs. +func (_u *ChannelMonitorUpdateOne) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne { + _u.mutation.RemoveHistoryIDs(ids...) + return _u +} + +// RemoveHistory removes "history" edges to ChannelMonitorHistory entities. +func (_u *ChannelMonitorUpdateOne) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveHistoryIDs(ids...) +} + +// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity. +func (_u *ChannelMonitorUpdateOne) ClearDailyRollups() *ChannelMonitorUpdateOne { + _u.mutation.ClearDailyRollups() + return _u +} + +// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs. +func (_u *ChannelMonitorUpdateOne) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne { + _u.mutation.RemoveDailyRollupIDs(ids...) + return _u +} + +// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities. 
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveDailyRollupIDs(ids...) +} + +// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (_u *ChannelMonitorUpdateOne) ClearRequestTemplate() *ChannelMonitorUpdateOne { + _u.mutation.ClearRequestTemplate() + return _u +} + +// Where appends a list predicates to the ChannelMonitorUpdate builder. +func (_u *ChannelMonitorUpdateOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ChannelMonitorUpdateOne) Select(field string, fields ...string) *ChannelMonitorUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ChannelMonitor entity. +func (_u *ChannelMonitorUpdateOne) Save(ctx context.Context) (*ChannelMonitor, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorUpdateOne) SaveX(ctx context.Context) *ChannelMonitor { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *ChannelMonitorUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *ChannelMonitorUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := channelmonitor.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := channelmonitor.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)} + } + } + if v, ok := _u.mutation.Provider(); ok { + if err := channelmonitor.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)} + } + } + if v, ok := _u.mutation.Endpoint(); ok { + if err := channelmonitor.EndpointValidator(v); err != nil { + return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)} + } + } + if v, ok := _u.mutation.APIKeyEncrypted(); ok { + if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil { + return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)} + } + } + if v, ok := _u.mutation.PrimaryModel(); ok { + if err := channelmonitor.PrimaryModelValidator(v); err != nil { + return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)} + } + } + if v, ok := _u.mutation.GroupName(); ok { + if err := channelmonitor.GroupNameValidator(v); err != nil { + return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)} + } + } + if v, ok := _u.mutation.IntervalSeconds(); ok { + if err := channelmonitor.IntervalSecondsValidator(v); err != nil { + return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: 
validator failed for field "ChannelMonitor.interval_seconds": %w`, err)} + } + } + if v, ok := _u.mutation.BodyOverrideMode(); ok { + if err := channelmonitor.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitor, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitor.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID) + for _, f := range fields { + if !channelmonitor.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != channelmonitor.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(channelmonitor.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Provider(); ok { + _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value) + } + if value, ok := _u.mutation.Endpoint(); ok { + _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value) + } + if value, ok := _u.mutation.APIKeyEncrypted(); ok { + 
_spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value) + } + if value, ok := _u.mutation.PrimaryModel(); ok { + _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value) + } + if value, ok := _u.mutation.ExtraModels(); ok { + _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedExtraModels(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, channelmonitor.FieldExtraModels, value) + }) + } + if value, ok := _u.mutation.GroupName(); ok { + _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value) + } + if _u.mutation.GroupNameCleared() { + _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.IntervalSeconds(); ok { + _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedIntervalSeconds(); ok { + _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value) + } + if value, ok := _u.mutation.LastCheckedAt(); ok { + _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value) + } + if _u.mutation.LastCheckedAtCleared() { + _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime) + } + if value, ok := _u.mutation.CreatedBy(); ok { + _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCreatedBy(); ok { + _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value) + } + if value, ok := _u.mutation.BodyOverrideMode(); ok { + _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value) + } + if value, ok := _u.mutation.BodyOverride(); ok { + 
_spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value) + } + if _u.mutation.BodyOverrideCleared() { + _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON) + } + if _u.mutation.HistoryCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.HistoryTable, + Columns: []string{channelmonitor.HistoryColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.DailyRollupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, 
edge) + } + if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: channelmonitor.DailyRollupsTable, + Columns: []string{channelmonitor.DailyRollupsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.RequestTemplateCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: channelmonitor.RequestTemplateTable, + Columns: []string{channelmonitor.RequestTemplateColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: channelmonitor.RequestTemplateTable, + Columns: []string{channelmonitor.RequestTemplateColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = 
append(_spec.Edges.Add, edge) + } + _node = &ChannelMonitor{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitor.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup.go new file mode 100644 index 0000000000000000000000000000000000000000..78a5f48916d62fe0e89b739aa0271f2f580bcfef --- /dev/null +++ b/backend/ent/channelmonitordailyrollup.go @@ -0,0 +1,278 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" +) + +// ChannelMonitorDailyRollup is the model entity for the ChannelMonitorDailyRollup schema. +type ChannelMonitorDailyRollup struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // MonitorID holds the value of the "monitor_id" field. + MonitorID int64 `json:"monitor_id,omitempty"` + // Model holds the value of the "model" field. + Model string `json:"model,omitempty"` + // BucketDate holds the value of the "bucket_date" field. + BucketDate time.Time `json:"bucket_date,omitempty"` + // TotalChecks holds the value of the "total_checks" field. + TotalChecks int `json:"total_checks,omitempty"` + // OkCount holds the value of the "ok_count" field. + OkCount int `json:"ok_count,omitempty"` + // OperationalCount holds the value of the "operational_count" field. + OperationalCount int `json:"operational_count,omitempty"` + // DegradedCount holds the value of the "degraded_count" field. 
+ DegradedCount int `json:"degraded_count,omitempty"` + // FailedCount holds the value of the "failed_count" field. + FailedCount int `json:"failed_count,omitempty"` + // ErrorCount holds the value of the "error_count" field. + ErrorCount int `json:"error_count,omitempty"` + // SumLatencyMs holds the value of the "sum_latency_ms" field. + SumLatencyMs int64 `json:"sum_latency_ms,omitempty"` + // CountLatency holds the value of the "count_latency" field. + CountLatency int `json:"count_latency,omitempty"` + // SumPingLatencyMs holds the value of the "sum_ping_latency_ms" field. + SumPingLatencyMs int64 `json:"sum_ping_latency_ms,omitempty"` + // CountPingLatency holds the value of the "count_ping_latency" field. + CountPingLatency int `json:"count_ping_latency,omitempty"` + // ComputedAt holds the value of the "computed_at" field. + ComputedAt time.Time `json:"computed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ChannelMonitorDailyRollupQuery when eager-loading is set. + Edges ChannelMonitorDailyRollupEdges `json:"edges"` + selectValues sql.SelectValues +} + +// ChannelMonitorDailyRollupEdges holds the relations/edges for other nodes in the graph. +type ChannelMonitorDailyRollupEdges struct { + // Monitor holds the value of the monitor edge. + Monitor *ChannelMonitor `json:"monitor,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// MonitorOrErr returns the Monitor value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. 
+func (e ChannelMonitorDailyRollupEdges) MonitorOrErr() (*ChannelMonitor, error) { + if e.Monitor != nil { + return e.Monitor, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: channelmonitor.Label} + } + return nil, &NotLoadedError{edge: "monitor"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ChannelMonitorDailyRollup) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case channelmonitordailyrollup.FieldID, channelmonitordailyrollup.FieldMonitorID, channelmonitordailyrollup.FieldTotalChecks, channelmonitordailyrollup.FieldOkCount, channelmonitordailyrollup.FieldOperationalCount, channelmonitordailyrollup.FieldDegradedCount, channelmonitordailyrollup.FieldFailedCount, channelmonitordailyrollup.FieldErrorCount, channelmonitordailyrollup.FieldSumLatencyMs, channelmonitordailyrollup.FieldCountLatency, channelmonitordailyrollup.FieldSumPingLatencyMs, channelmonitordailyrollup.FieldCountPingLatency: + values[i] = new(sql.NullInt64) + case channelmonitordailyrollup.FieldModel: + values[i] = new(sql.NullString) + case channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ChannelMonitorDailyRollup fields. 
+func (_m *ChannelMonitorDailyRollup) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case channelmonitordailyrollup.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case channelmonitordailyrollup.FieldMonitorID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field monitor_id", values[i]) + } else if value.Valid { + _m.MonitorID = value.Int64 + } + case channelmonitordailyrollup.FieldModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[i]) + } else if value.Valid { + _m.Model = value.String + } + case channelmonitordailyrollup.FieldBucketDate: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field bucket_date", values[i]) + } else if value.Valid { + _m.BucketDate = value.Time + } + case channelmonitordailyrollup.FieldTotalChecks: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field total_checks", values[i]) + } else if value.Valid { + _m.TotalChecks = int(value.Int64) + } + case channelmonitordailyrollup.FieldOkCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field ok_count", values[i]) + } else if value.Valid { + _m.OkCount = int(value.Int64) + } + case channelmonitordailyrollup.FieldOperationalCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field operational_count", values[i]) + } else if value.Valid { + _m.OperationalCount = int(value.Int64) + } + case channelmonitordailyrollup.FieldDegradedCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return 
fmt.Errorf("unexpected type %T for field degraded_count", values[i]) + } else if value.Valid { + _m.DegradedCount = int(value.Int64) + } + case channelmonitordailyrollup.FieldFailedCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field failed_count", values[i]) + } else if value.Valid { + _m.FailedCount = int(value.Int64) + } + case channelmonitordailyrollup.FieldErrorCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field error_count", values[i]) + } else if value.Valid { + _m.ErrorCount = int(value.Int64) + } + case channelmonitordailyrollup.FieldSumLatencyMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sum_latency_ms", values[i]) + } else if value.Valid { + _m.SumLatencyMs = value.Int64 + } + case channelmonitordailyrollup.FieldCountLatency: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field count_latency", values[i]) + } else if value.Valid { + _m.CountLatency = int(value.Int64) + } + case channelmonitordailyrollup.FieldSumPingLatencyMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sum_ping_latency_ms", values[i]) + } else if value.Valid { + _m.SumPingLatencyMs = value.Int64 + } + case channelmonitordailyrollup.FieldCountPingLatency: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field count_ping_latency", values[i]) + } else if value.Valid { + _m.CountPingLatency = int(value.Int64) + } + case channelmonitordailyrollup.FieldComputedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field computed_at", values[i]) + } else if value.Valid { + _m.ComputedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was 
dynamically selected and assigned to the ChannelMonitorDailyRollup. +// This includes values selected through modifiers, order, etc. +func (_m *ChannelMonitorDailyRollup) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryMonitor queries the "monitor" edge of the ChannelMonitorDailyRollup entity. +func (_m *ChannelMonitorDailyRollup) QueryMonitor() *ChannelMonitorQuery { + return NewChannelMonitorDailyRollupClient(_m.config).QueryMonitor(_m) +} + +// Update returns a builder for updating this ChannelMonitorDailyRollup. +// Note that you need to call ChannelMonitorDailyRollup.Unwrap() before calling this method if this ChannelMonitorDailyRollup +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ChannelMonitorDailyRollup) Update() *ChannelMonitorDailyRollupUpdateOne { + return NewChannelMonitorDailyRollupClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ChannelMonitorDailyRollup entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ChannelMonitorDailyRollup) Unwrap() *ChannelMonitorDailyRollup { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ChannelMonitorDailyRollup is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *ChannelMonitorDailyRollup) String() string { + var builder strings.Builder + builder.WriteString("ChannelMonitorDailyRollup(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("monitor_id=") + builder.WriteString(fmt.Sprintf("%v", _m.MonitorID)) + builder.WriteString(", ") + builder.WriteString("model=") + builder.WriteString(_m.Model) + builder.WriteString(", ") + builder.WriteString("bucket_date=") + builder.WriteString(_m.BucketDate.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("total_checks=") + builder.WriteString(fmt.Sprintf("%v", _m.TotalChecks)) + builder.WriteString(", ") + builder.WriteString("ok_count=") + builder.WriteString(fmt.Sprintf("%v", _m.OkCount)) + builder.WriteString(", ") + builder.WriteString("operational_count=") + builder.WriteString(fmt.Sprintf("%v", _m.OperationalCount)) + builder.WriteString(", ") + builder.WriteString("degraded_count=") + builder.WriteString(fmt.Sprintf("%v", _m.DegradedCount)) + builder.WriteString(", ") + builder.WriteString("failed_count=") + builder.WriteString(fmt.Sprintf("%v", _m.FailedCount)) + builder.WriteString(", ") + builder.WriteString("error_count=") + builder.WriteString(fmt.Sprintf("%v", _m.ErrorCount)) + builder.WriteString(", ") + builder.WriteString("sum_latency_ms=") + builder.WriteString(fmt.Sprintf("%v", _m.SumLatencyMs)) + builder.WriteString(", ") + builder.WriteString("count_latency=") + builder.WriteString(fmt.Sprintf("%v", _m.CountLatency)) + builder.WriteString(", ") + builder.WriteString("sum_ping_latency_ms=") + builder.WriteString(fmt.Sprintf("%v", _m.SumPingLatencyMs)) + builder.WriteString(", ") + builder.WriteString("count_ping_latency=") + builder.WriteString(fmt.Sprintf("%v", _m.CountPingLatency)) + builder.WriteString(", ") + builder.WriteString("computed_at=") + builder.WriteString(_m.ComputedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// ChannelMonitorDailyRollups is a parsable slice 
of ChannelMonitorDailyRollup. +type ChannelMonitorDailyRollups []*ChannelMonitorDailyRollup diff --git a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go new file mode 100644 index 0000000000000000000000000000000000000000..e7cb9307e6f02d897808d441531e44f3a77eade6 --- /dev/null +++ b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go @@ -0,0 +1,206 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitordailyrollup + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the channelmonitordailyrollup type in the database. + Label = "channel_monitor_daily_rollup" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldMonitorID holds the string denoting the monitor_id field in the database. + FieldMonitorID = "monitor_id" + // FieldModel holds the string denoting the model field in the database. + FieldModel = "model" + // FieldBucketDate holds the string denoting the bucket_date field in the database. + FieldBucketDate = "bucket_date" + // FieldTotalChecks holds the string denoting the total_checks field in the database. + FieldTotalChecks = "total_checks" + // FieldOkCount holds the string denoting the ok_count field in the database. + FieldOkCount = "ok_count" + // FieldOperationalCount holds the string denoting the operational_count field in the database. + FieldOperationalCount = "operational_count" + // FieldDegradedCount holds the string denoting the degraded_count field in the database. + FieldDegradedCount = "degraded_count" + // FieldFailedCount holds the string denoting the failed_count field in the database. + FieldFailedCount = "failed_count" + // FieldErrorCount holds the string denoting the error_count field in the database. 
+ FieldErrorCount = "error_count" + // FieldSumLatencyMs holds the string denoting the sum_latency_ms field in the database. + FieldSumLatencyMs = "sum_latency_ms" + // FieldCountLatency holds the string denoting the count_latency field in the database. + FieldCountLatency = "count_latency" + // FieldSumPingLatencyMs holds the string denoting the sum_ping_latency_ms field in the database. + FieldSumPingLatencyMs = "sum_ping_latency_ms" + // FieldCountPingLatency holds the string denoting the count_ping_latency field in the database. + FieldCountPingLatency = "count_ping_latency" + // FieldComputedAt holds the string denoting the computed_at field in the database. + FieldComputedAt = "computed_at" + // EdgeMonitor holds the string denoting the monitor edge name in mutations. + EdgeMonitor = "monitor" + // Table holds the table name of the channelmonitordailyrollup in the database. + Table = "channel_monitor_daily_rollups" + // MonitorTable is the table that holds the monitor relation/edge. + MonitorTable = "channel_monitor_daily_rollups" + // MonitorInverseTable is the table name for the ChannelMonitor entity. + // It exists in this package in order to avoid circular dependency with the "channelmonitor" package. + MonitorInverseTable = "channel_monitors" + // MonitorColumn is the table column denoting the monitor relation/edge. + MonitorColumn = "monitor_id" +) + +// Columns holds all SQL columns for channelmonitordailyrollup fields. +var Columns = []string{ + FieldID, + FieldMonitorID, + FieldModel, + FieldBucketDate, + FieldTotalChecks, + FieldOkCount, + FieldOperationalCount, + FieldDegradedCount, + FieldFailedCount, + FieldErrorCount, + FieldSumLatencyMs, + FieldCountLatency, + FieldSumPingLatencyMs, + FieldCountPingLatency, + FieldComputedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). 
+func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // ModelValidator is a validator for the "model" field. It is called by the builders before save. + ModelValidator func(string) error + // DefaultTotalChecks holds the default value on creation for the "total_checks" field. + DefaultTotalChecks int + // DefaultOkCount holds the default value on creation for the "ok_count" field. + DefaultOkCount int + // DefaultOperationalCount holds the default value on creation for the "operational_count" field. + DefaultOperationalCount int + // DefaultDegradedCount holds the default value on creation for the "degraded_count" field. + DefaultDegradedCount int + // DefaultFailedCount holds the default value on creation for the "failed_count" field. + DefaultFailedCount int + // DefaultErrorCount holds the default value on creation for the "error_count" field. + DefaultErrorCount int + // DefaultSumLatencyMs holds the default value on creation for the "sum_latency_ms" field. + DefaultSumLatencyMs int64 + // DefaultCountLatency holds the default value on creation for the "count_latency" field. + DefaultCountLatency int + // DefaultSumPingLatencyMs holds the default value on creation for the "sum_ping_latency_ms" field. + DefaultSumPingLatencyMs int64 + // DefaultCountPingLatency holds the default value on creation for the "count_ping_latency" field. + DefaultCountPingLatency int + // DefaultComputedAt holds the default value on creation for the "computed_at" field. + DefaultComputedAt func() time.Time + // UpdateDefaultComputedAt holds the default value on update for the "computed_at" field. + UpdateDefaultComputedAt func() time.Time +) + +// OrderOption defines the ordering options for the ChannelMonitorDailyRollup queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByMonitorID orders the results by the monitor_id field. +func ByMonitorID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonitorID, opts...).ToFunc() +} + +// ByModel orders the results by the model field. +func ByModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModel, opts...).ToFunc() +} + +// ByBucketDate orders the results by the bucket_date field. +func ByBucketDate(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBucketDate, opts...).ToFunc() +} + +// ByTotalChecks orders the results by the total_checks field. +func ByTotalChecks(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalChecks, opts...).ToFunc() +} + +// ByOkCount orders the results by the ok_count field. +func ByOkCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOkCount, opts...).ToFunc() +} + +// ByOperationalCount orders the results by the operational_count field. +func ByOperationalCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOperationalCount, opts...).ToFunc() +} + +// ByDegradedCount orders the results by the degraded_count field. +func ByDegradedCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDegradedCount, opts...).ToFunc() +} + +// ByFailedCount orders the results by the failed_count field. +func ByFailedCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFailedCount, opts...).ToFunc() +} + +// ByErrorCount orders the results by the error_count field. +func ByErrorCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorCount, opts...).ToFunc() +} + +// BySumLatencyMs orders the results by the sum_latency_ms field. 
+func BySumLatencyMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSumLatencyMs, opts...).ToFunc() +} + +// ByCountLatency orders the results by the count_latency field. +func ByCountLatency(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCountLatency, opts...).ToFunc() +} + +// BySumPingLatencyMs orders the results by the sum_ping_latency_ms field. +func BySumPingLatencyMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSumPingLatencyMs, opts...).ToFunc() +} + +// ByCountPingLatency orders the results by the count_ping_latency field. +func ByCountPingLatency(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCountPingLatency, opts...).ToFunc() +} + +// ByComputedAt orders the results by the computed_at field. +func ByComputedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldComputedAt, opts...).ToFunc() +} + +// ByMonitorField orders the results by monitor field. +func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...)) + } +} +func newMonitorStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MonitorInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn), + ) +} diff --git a/backend/ent/channelmonitordailyrollup/where.go b/backend/ent/channelmonitordailyrollup/where.go new file mode 100644 index 0000000000000000000000000000000000000000..424c957ec9d65693dae21e9e51d6a0da68a3e313 --- /dev/null +++ b/backend/ent/channelmonitordailyrollup/where.go @@ -0,0 +1,729 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package channelmonitordailyrollup + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldID, id)) +} + +// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ. 
+func MonitorID(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v)) +} + +// Model applies equality check predicate on the "model" field. It's identical to ModelEQ. +func Model(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v)) +} + +// BucketDate applies equality check predicate on the "bucket_date" field. It's identical to BucketDateEQ. +func BucketDate(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v)) +} + +// TotalChecks applies equality check predicate on the "total_checks" field. It's identical to TotalChecksEQ. +func TotalChecks(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v)) +} + +// OkCount applies equality check predicate on the "ok_count" field. It's identical to OkCountEQ. +func OkCount(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v)) +} + +// OperationalCount applies equality check predicate on the "operational_count" field. It's identical to OperationalCountEQ. +func OperationalCount(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v)) +} + +// DegradedCount applies equality check predicate on the "degraded_count" field. It's identical to DegradedCountEQ. +func DegradedCount(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v)) +} + +// FailedCount applies equality check predicate on the "failed_count" field. It's identical to FailedCountEQ. 
+func FailedCount(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v)) +} + +// ErrorCount applies equality check predicate on the "error_count" field. It's identical to ErrorCountEQ. +func ErrorCount(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v)) +} + +// SumLatencyMs applies equality check predicate on the "sum_latency_ms" field. It's identical to SumLatencyMsEQ. +func SumLatencyMs(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v)) +} + +// CountLatency applies equality check predicate on the "count_latency" field. It's identical to CountLatencyEQ. +func CountLatency(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v)) +} + +// SumPingLatencyMs applies equality check predicate on the "sum_ping_latency_ms" field. It's identical to SumPingLatencyMsEQ. +func SumPingLatencyMs(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v)) +} + +// CountPingLatency applies equality check predicate on the "count_ping_latency" field. It's identical to CountPingLatencyEQ. +func CountPingLatency(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v)) +} + +// ComputedAt applies equality check predicate on the "computed_at" field. It's identical to ComputedAtEQ. +func ComputedAt(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v)) +} + +// MonitorIDEQ applies the EQ predicate on the "monitor_id" field. 
+func MonitorIDEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v)) +} + +// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field. +func MonitorIDNEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldMonitorID, v)) +} + +// MonitorIDIn applies the In predicate on the "monitor_id" field. +func MonitorIDIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldMonitorID, vs...)) +} + +// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field. +func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldMonitorID, vs...)) +} + +// ModelEQ applies the EQ predicate on the "model" field. +func ModelEQ(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v)) +} + +// ModelNEQ applies the NEQ predicate on the "model" field. +func ModelNEQ(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldModel, v)) +} + +// ModelIn applies the In predicate on the "model" field. +func ModelIn(vs ...string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldModel, vs...)) +} + +// ModelNotIn applies the NotIn predicate on the "model" field. +func ModelNotIn(vs ...string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldModel, vs...)) +} + +// ModelGT applies the GT predicate on the "model" field. +func ModelGT(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldModel, v)) +} + +// ModelGTE applies the GTE predicate on the "model" field. 
+func ModelGTE(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldModel, v)) +} + +// ModelLT applies the LT predicate on the "model" field. +func ModelLT(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldModel, v)) +} + +// ModelLTE applies the LTE predicate on the "model" field. +func ModelLTE(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldModel, v)) +} + +// ModelContains applies the Contains predicate on the "model" field. +func ModelContains(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldContains(FieldModel, v)) +} + +// ModelHasPrefix applies the HasPrefix predicate on the "model" field. +func ModelHasPrefix(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldHasPrefix(FieldModel, v)) +} + +// ModelHasSuffix applies the HasSuffix predicate on the "model" field. +func ModelHasSuffix(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldHasSuffix(FieldModel, v)) +} + +// ModelEqualFold applies the EqualFold predicate on the "model" field. +func ModelEqualFold(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEqualFold(FieldModel, v)) +} + +// ModelContainsFold applies the ContainsFold predicate on the "model" field. +func ModelContainsFold(v string) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldContainsFold(FieldModel, v)) +} + +// BucketDateEQ applies the EQ predicate on the "bucket_date" field. +func BucketDateEQ(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v)) +} + +// BucketDateNEQ applies the NEQ predicate on the "bucket_date" field. 
+func BucketDateNEQ(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldBucketDate, v)) +} + +// BucketDateIn applies the In predicate on the "bucket_date" field. +func BucketDateIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldBucketDate, vs...)) +} + +// BucketDateNotIn applies the NotIn predicate on the "bucket_date" field. +func BucketDateNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldBucketDate, vs...)) +} + +// BucketDateGT applies the GT predicate on the "bucket_date" field. +func BucketDateGT(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldBucketDate, v)) +} + +// BucketDateGTE applies the GTE predicate on the "bucket_date" field. +func BucketDateGTE(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldBucketDate, v)) +} + +// BucketDateLT applies the LT predicate on the "bucket_date" field. +func BucketDateLT(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldBucketDate, v)) +} + +// BucketDateLTE applies the LTE predicate on the "bucket_date" field. +func BucketDateLTE(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldBucketDate, v)) +} + +// TotalChecksEQ applies the EQ predicate on the "total_checks" field. +func TotalChecksEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v)) +} + +// TotalChecksNEQ applies the NEQ predicate on the "total_checks" field. 
+func TotalChecksNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldTotalChecks, v)) +} + +// TotalChecksIn applies the In predicate on the "total_checks" field. +func TotalChecksIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldTotalChecks, vs...)) +} + +// TotalChecksNotIn applies the NotIn predicate on the "total_checks" field. +func TotalChecksNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldTotalChecks, vs...)) +} + +// TotalChecksGT applies the GT predicate on the "total_checks" field. +func TotalChecksGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldTotalChecks, v)) +} + +// TotalChecksGTE applies the GTE predicate on the "total_checks" field. +func TotalChecksGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldTotalChecks, v)) +} + +// TotalChecksLT applies the LT predicate on the "total_checks" field. +func TotalChecksLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldTotalChecks, v)) +} + +// TotalChecksLTE applies the LTE predicate on the "total_checks" field. +func TotalChecksLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldTotalChecks, v)) +} + +// OkCountEQ applies the EQ predicate on the "ok_count" field. +func OkCountEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v)) +} + +// OkCountNEQ applies the NEQ predicate on the "ok_count" field. +func OkCountNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOkCount, v)) +} + +// OkCountIn applies the In predicate on the "ok_count" field. 
+func OkCountIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOkCount, vs...)) +} + +// OkCountNotIn applies the NotIn predicate on the "ok_count" field. +func OkCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOkCount, vs...)) +} + +// OkCountGT applies the GT predicate on the "ok_count" field. +func OkCountGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOkCount, v)) +} + +// OkCountGTE applies the GTE predicate on the "ok_count" field. +func OkCountGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOkCount, v)) +} + +// OkCountLT applies the LT predicate on the "ok_count" field. +func OkCountLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOkCount, v)) +} + +// OkCountLTE applies the LTE predicate on the "ok_count" field. +func OkCountLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOkCount, v)) +} + +// OperationalCountEQ applies the EQ predicate on the "operational_count" field. +func OperationalCountEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v)) +} + +// OperationalCountNEQ applies the NEQ predicate on the "operational_count" field. +func OperationalCountNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOperationalCount, v)) +} + +// OperationalCountIn applies the In predicate on the "operational_count" field. 
+func OperationalCountIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOperationalCount, vs...)) +} + +// OperationalCountNotIn applies the NotIn predicate on the "operational_count" field. +func OperationalCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOperationalCount, vs...)) +} + +// OperationalCountGT applies the GT predicate on the "operational_count" field. +func OperationalCountGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOperationalCount, v)) +} + +// OperationalCountGTE applies the GTE predicate on the "operational_count" field. +func OperationalCountGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOperationalCount, v)) +} + +// OperationalCountLT applies the LT predicate on the "operational_count" field. +func OperationalCountLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOperationalCount, v)) +} + +// OperationalCountLTE applies the LTE predicate on the "operational_count" field. +func OperationalCountLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOperationalCount, v)) +} + +// DegradedCountEQ applies the EQ predicate on the "degraded_count" field. +func DegradedCountEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v)) +} + +// DegradedCountNEQ applies the NEQ predicate on the "degraded_count" field. +func DegradedCountNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDegradedCount, v)) +} + +// DegradedCountIn applies the In predicate on the "degraded_count" field. 
+func DegradedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDegradedCount, vs...)) +} + +// DegradedCountNotIn applies the NotIn predicate on the "degraded_count" field. +func DegradedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDegradedCount, vs...)) +} + +// DegradedCountGT applies the GT predicate on the "degraded_count" field. +func DegradedCountGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDegradedCount, v)) +} + +// DegradedCountGTE applies the GTE predicate on the "degraded_count" field. +func DegradedCountGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDegradedCount, v)) +} + +// DegradedCountLT applies the LT predicate on the "degraded_count" field. +func DegradedCountLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDegradedCount, v)) +} + +// DegradedCountLTE applies the LTE predicate on the "degraded_count" field. +func DegradedCountLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDegradedCount, v)) +} + +// FailedCountEQ applies the EQ predicate on the "failed_count" field. +func FailedCountEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v)) +} + +// FailedCountNEQ applies the NEQ predicate on the "failed_count" field. +func FailedCountNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldFailedCount, v)) +} + +// FailedCountIn applies the In predicate on the "failed_count" field. 
+func FailedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldFailedCount, vs...)) +} + +// FailedCountNotIn applies the NotIn predicate on the "failed_count" field. +func FailedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldFailedCount, vs...)) +} + +// FailedCountGT applies the GT predicate on the "failed_count" field. +func FailedCountGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldFailedCount, v)) +} + +// FailedCountGTE applies the GTE predicate on the "failed_count" field. +func FailedCountGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldFailedCount, v)) +} + +// FailedCountLT applies the LT predicate on the "failed_count" field. +func FailedCountLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldFailedCount, v)) +} + +// FailedCountLTE applies the LTE predicate on the "failed_count" field. +func FailedCountLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldFailedCount, v)) +} + +// ErrorCountEQ applies the EQ predicate on the "error_count" field. +func ErrorCountEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v)) +} + +// ErrorCountNEQ applies the NEQ predicate on the "error_count" field. +func ErrorCountNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldErrorCount, v)) +} + +// ErrorCountIn applies the In predicate on the "error_count" field. 
+func ErrorCountIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldErrorCount, vs...)) +} + +// ErrorCountNotIn applies the NotIn predicate on the "error_count" field. +func ErrorCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldErrorCount, vs...)) +} + +// ErrorCountGT applies the GT predicate on the "error_count" field. +func ErrorCountGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldErrorCount, v)) +} + +// ErrorCountGTE applies the GTE predicate on the "error_count" field. +func ErrorCountGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldErrorCount, v)) +} + +// ErrorCountLT applies the LT predicate on the "error_count" field. +func ErrorCountLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldErrorCount, v)) +} + +// ErrorCountLTE applies the LTE predicate on the "error_count" field. +func ErrorCountLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldErrorCount, v)) +} + +// SumLatencyMsEQ applies the EQ predicate on the "sum_latency_ms" field. +func SumLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v)) +} + +// SumLatencyMsNEQ applies the NEQ predicate on the "sum_latency_ms" field. +func SumLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumLatencyMs, v)) +} + +// SumLatencyMsIn applies the In predicate on the "sum_latency_ms" field. 
+func SumLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumLatencyMs, vs...)) +} + +// SumLatencyMsNotIn applies the NotIn predicate on the "sum_latency_ms" field. +func SumLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumLatencyMs, vs...)) +} + +// SumLatencyMsGT applies the GT predicate on the "sum_latency_ms" field. +func SumLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumLatencyMs, v)) +} + +// SumLatencyMsGTE applies the GTE predicate on the "sum_latency_ms" field. +func SumLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumLatencyMs, v)) +} + +// SumLatencyMsLT applies the LT predicate on the "sum_latency_ms" field. +func SumLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumLatencyMs, v)) +} + +// SumLatencyMsLTE applies the LTE predicate on the "sum_latency_ms" field. +func SumLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumLatencyMs, v)) +} + +// CountLatencyEQ applies the EQ predicate on the "count_latency" field. +func CountLatencyEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v)) +} + +// CountLatencyNEQ applies the NEQ predicate on the "count_latency" field. +func CountLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountLatency, v)) +} + +// CountLatencyIn applies the In predicate on the "count_latency" field. 
+func CountLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountLatency, vs...)) +} + +// CountLatencyNotIn applies the NotIn predicate on the "count_latency" field. +func CountLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountLatency, vs...)) +} + +// CountLatencyGT applies the GT predicate on the "count_latency" field. +func CountLatencyGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountLatency, v)) +} + +// CountLatencyGTE applies the GTE predicate on the "count_latency" field. +func CountLatencyGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountLatency, v)) +} + +// CountLatencyLT applies the LT predicate on the "count_latency" field. +func CountLatencyLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountLatency, v)) +} + +// CountLatencyLTE applies the LTE predicate on the "count_latency" field. +func CountLatencyLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountLatency, v)) +} + +// SumPingLatencyMsEQ applies the EQ predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v)) +} + +// SumPingLatencyMsNEQ applies the NEQ predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumPingLatencyMs, v)) +} + +// SumPingLatencyMsIn applies the In predicate on the "sum_ping_latency_ms" field. 
+func SumPingLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumPingLatencyMs, vs...)) +} + +// SumPingLatencyMsNotIn applies the NotIn predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumPingLatencyMs, vs...)) +} + +// SumPingLatencyMsGT applies the GT predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumPingLatencyMs, v)) +} + +// SumPingLatencyMsGTE applies the GTE predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumPingLatencyMs, v)) +} + +// SumPingLatencyMsLT applies the LT predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumPingLatencyMs, v)) +} + +// SumPingLatencyMsLTE applies the LTE predicate on the "sum_ping_latency_ms" field. +func SumPingLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumPingLatencyMs, v)) +} + +// CountPingLatencyEQ applies the EQ predicate on the "count_ping_latency" field. +func CountPingLatencyEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v)) +} + +// CountPingLatencyNEQ applies the NEQ predicate on the "count_ping_latency" field. +func CountPingLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountPingLatency, v)) +} + +// CountPingLatencyIn applies the In predicate on the "count_ping_latency" field. 
+func CountPingLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountPingLatency, vs...)) +} + +// CountPingLatencyNotIn applies the NotIn predicate on the "count_ping_latency" field. +func CountPingLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountPingLatency, vs...)) +} + +// CountPingLatencyGT applies the GT predicate on the "count_ping_latency" field. +func CountPingLatencyGT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountPingLatency, v)) +} + +// CountPingLatencyGTE applies the GTE predicate on the "count_ping_latency" field. +func CountPingLatencyGTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountPingLatency, v)) +} + +// CountPingLatencyLT applies the LT predicate on the "count_ping_latency" field. +func CountPingLatencyLT(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountPingLatency, v)) +} + +// CountPingLatencyLTE applies the LTE predicate on the "count_ping_latency" field. +func CountPingLatencyLTE(v int) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountPingLatency, v)) +} + +// ComputedAtEQ applies the EQ predicate on the "computed_at" field. +func ComputedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v)) +} + +// ComputedAtNEQ applies the NEQ predicate on the "computed_at" field. +func ComputedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldComputedAt, v)) +} + +// ComputedAtIn applies the In predicate on the "computed_at" field. 
+func ComputedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldComputedAt, vs...)) +} + +// ComputedAtNotIn applies the NotIn predicate on the "computed_at" field. +func ComputedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldComputedAt, vs...)) +} + +// ComputedAtGT applies the GT predicate on the "computed_at" field. +func ComputedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldComputedAt, v)) +} + +// ComputedAtGTE applies the GTE predicate on the "computed_at" field. +func ComputedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldComputedAt, v)) +} + +// ComputedAtLT applies the LT predicate on the "computed_at" field. +func ComputedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldComputedAt, v)) +} + +// ComputedAtLTE applies the LTE predicate on the "computed_at" field. +func ComputedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldComputedAt, v)) +} + +// HasMonitor applies the HasEdge predicate on the "monitor" edge. +func HasMonitor() predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates). 
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) { + step := newMonitorStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup { + return predicate.ChannelMonitorDailyRollup(sql.NotPredicates(p)) +} diff --git a/backend/ent/channelmonitordailyrollup_create.go b/backend/ent/channelmonitordailyrollup_create.go new file mode 100644 index 0000000000000000000000000000000000000000..5f8754babd59a69337faec6b25a99d91497b1f00 --- /dev/null +++ b/backend/ent/channelmonitordailyrollup_create.go @@ -0,0 +1,1509 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" +) + +// ChannelMonitorDailyRollupCreate is the builder for creating a ChannelMonitorDailyRollup entity. +type ChannelMonitorDailyRollupCreate struct { + config + mutation *ChannelMonitorDailyRollupMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetMonitorID sets the "monitor_id" field. 
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitorID(v int64) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetMonitorID(v) + return _c +} + +// SetModel sets the "model" field. +func (_c *ChannelMonitorDailyRollupCreate) SetModel(v string) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetModel(v) + return _c +} + +// SetBucketDate sets the "bucket_date" field. +func (_c *ChannelMonitorDailyRollupCreate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetBucketDate(v) + return _c +} + +// SetTotalChecks sets the "total_checks" field. +func (_c *ChannelMonitorDailyRollupCreate) SetTotalChecks(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetTotalChecks(v) + return _c +} + +// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetTotalChecks(*v) + } + return _c +} + +// SetOkCount sets the "ok_count" field. +func (_c *ChannelMonitorDailyRollupCreate) SetOkCount(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetOkCount(v) + return _c +} + +// SetNillableOkCount sets the "ok_count" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetOkCount(*v) + } + return _c +} + +// SetOperationalCount sets the "operational_count" field. +func (_c *ChannelMonitorDailyRollupCreate) SetOperationalCount(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetOperationalCount(v) + return _c +} + +// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetOperationalCount(*v) + } + return _c +} + +// SetDegradedCount sets the "degraded_count" field. 
+func (_c *ChannelMonitorDailyRollupCreate) SetDegradedCount(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetDegradedCount(v) + return _c +} + +// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetDegradedCount(*v) + } + return _c +} + +// SetFailedCount sets the "failed_count" field. +func (_c *ChannelMonitorDailyRollupCreate) SetFailedCount(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetFailedCount(v) + return _c +} + +// SetNillableFailedCount sets the "failed_count" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetFailedCount(*v) + } + return _c +} + +// SetErrorCount sets the "error_count" field. +func (_c *ChannelMonitorDailyRollupCreate) SetErrorCount(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetErrorCount(v) + return _c +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetErrorCount(*v) + } + return _c +} + +// SetSumLatencyMs sets the "sum_latency_ms" field. +func (_c *ChannelMonitorDailyRollupCreate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetSumLatencyMs(v) + return _c +} + +// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetSumLatencyMs(*v) + } + return _c +} + +// SetCountLatency sets the "count_latency" field. 
+func (_c *ChannelMonitorDailyRollupCreate) SetCountLatency(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetCountLatency(v) + return _c +} + +// SetNillableCountLatency sets the "count_latency" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetCountLatency(*v) + } + return _c +} + +// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field. +func (_c *ChannelMonitorDailyRollupCreate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetSumPingLatencyMs(v) + return _c +} + +// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetSumPingLatencyMs(*v) + } + return _c +} + +// SetCountPingLatency sets the "count_ping_latency" field. +func (_c *ChannelMonitorDailyRollupCreate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetCountPingLatency(v) + return _c +} + +// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetCountPingLatency(*v) + } + return _c +} + +// SetComputedAt sets the "computed_at" field. +func (_c *ChannelMonitorDailyRollupCreate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupCreate { + _c.mutation.SetComputedAt(v) + return _c +} + +// SetNillableComputedAt sets the "computed_at" field if the given value is not nil. +func (_c *ChannelMonitorDailyRollupCreate) SetNillableComputedAt(v *time.Time) *ChannelMonitorDailyRollupCreate { + if v != nil { + _c.SetComputedAt(*v) + } + return _c +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. 
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupCreate { + return _c.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder. +func (_c *ChannelMonitorDailyRollupCreate) Mutation() *ChannelMonitorDailyRollupMutation { + return _c.mutation +} + +// Save creates the ChannelMonitorDailyRollup in the database. +func (_c *ChannelMonitorDailyRollupCreate) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ChannelMonitorDailyRollupCreate) SaveX(ctx context.Context) *ChannelMonitorDailyRollup { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ChannelMonitorDailyRollupCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorDailyRollupCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *ChannelMonitorDailyRollupCreate) defaults() { + if _, ok := _c.mutation.TotalChecks(); !ok { + v := channelmonitordailyrollup.DefaultTotalChecks + _c.mutation.SetTotalChecks(v) + } + if _, ok := _c.mutation.OkCount(); !ok { + v := channelmonitordailyrollup.DefaultOkCount + _c.mutation.SetOkCount(v) + } + if _, ok := _c.mutation.OperationalCount(); !ok { + v := channelmonitordailyrollup.DefaultOperationalCount + _c.mutation.SetOperationalCount(v) + } + if _, ok := _c.mutation.DegradedCount(); !ok { + v := channelmonitordailyrollup.DefaultDegradedCount + _c.mutation.SetDegradedCount(v) + } + if _, ok := _c.mutation.FailedCount(); !ok { + v := channelmonitordailyrollup.DefaultFailedCount + _c.mutation.SetFailedCount(v) + } + if _, ok := _c.mutation.ErrorCount(); !ok { + v := channelmonitordailyrollup.DefaultErrorCount + _c.mutation.SetErrorCount(v) + } + if _, ok := _c.mutation.SumLatencyMs(); !ok { + v := channelmonitordailyrollup.DefaultSumLatencyMs + _c.mutation.SetSumLatencyMs(v) + } + if _, ok := _c.mutation.CountLatency(); !ok { + v := channelmonitordailyrollup.DefaultCountLatency + _c.mutation.SetCountLatency(v) + } + if _, ok := _c.mutation.SumPingLatencyMs(); !ok { + v := channelmonitordailyrollup.DefaultSumPingLatencyMs + _c.mutation.SetSumPingLatencyMs(v) + } + if _, ok := _c.mutation.CountPingLatency(); !ok { + v := channelmonitordailyrollup.DefaultCountPingLatency + _c.mutation.SetCountPingLatency(v) + } + if _, ok := _c.mutation.ComputedAt(); !ok { + v := channelmonitordailyrollup.DefaultComputedAt() + _c.mutation.SetComputedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
// check validates that every required field is present on the mutation, runs
// the schema validator for "model", and verifies the required "monitor" edge.
// It is called after defaults(), so fields with defaults are already set.
func (_c *ChannelMonitorDailyRollupCreate) check() error {
	if _, ok := _c.mutation.MonitorID(); !ok {
		return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.monitor_id"`)}
	}
	if _, ok := _c.mutation.Model(); !ok {
		return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.model"`)}
	}
	if v, ok := _c.mutation.Model(); ok {
		if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
			return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
		}
	}
	if _, ok := _c.mutation.BucketDate(); !ok {
		return &ValidationError{Name: "bucket_date", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.bucket_date"`)}
	}
	if _, ok := _c.mutation.TotalChecks(); !ok {
		return &ValidationError{Name: "total_checks", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.total_checks"`)}
	}
	if _, ok := _c.mutation.OkCount(); !ok {
		return &ValidationError{Name: "ok_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.ok_count"`)}
	}
	if _, ok := _c.mutation.OperationalCount(); !ok {
		return &ValidationError{Name: "operational_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.operational_count"`)}
	}
	if _, ok := _c.mutation.DegradedCount(); !ok {
		return &ValidationError{Name: "degraded_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.degraded_count"`)}
	}
	if _, ok := _c.mutation.FailedCount(); !ok {
		return &ValidationError{Name: "failed_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.failed_count"`)}
	}
	if _, ok := _c.mutation.ErrorCount(); !ok {
		return &ValidationError{Name: "error_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.error_count"`)}
	}
	if _, ok := _c.mutation.SumLatencyMs(); !ok {
		return &ValidationError{Name: "sum_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_latency_ms"`)}
	}
	if _, ok := _c.mutation.CountLatency(); !ok {
		return &ValidationError{Name: "count_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_latency"`)}
	}
	if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
		return &ValidationError{Name: "sum_ping_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_ping_latency_ms"`)}
	}
	if _, ok := _c.mutation.CountPingLatency(); !ok {
		return &ValidationError{Name: "count_ping_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_ping_latency"`)}
	}
	if _, ok := _c.mutation.ComputedAt(); !ok {
		return &ValidationError{Name: "computed_at", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.computed_at"`)}
	}
	if len(_c.mutation.MonitorIDs()) == 0 {
		return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorDailyRollup.monitor"`)}
	}
	return nil
}

// sqlSave validates the builder, executes the INSERT via sqlgraph, assigns
// the database-generated ID back onto the returned node, and marks the
// mutation as done. Constraint violations are wrapped as *ConstraintError.
func (_c *ChannelMonitorDailyRollupCreate) sqlSave(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
	if err := _c.check(); err != nil {
		return nil, err
	}
	_node, _spec := _c.createSpec()
	if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
		if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{msg: err.Error(), wrap: err}
		}
		return nil, err
	}
	id := _spec.ID.Value.(int64)
	_node.ID = int64(id)
	_c.mutation.id = &_node.ID
	_c.mutation.done = true
	return _node, nil
}

// createSpec builds the sqlgraph.CreateSpec describing the row to insert
// (including any ON CONFLICT options) alongside the in-memory node that
// mirrors the values being written. The monitor edge is an M2O reference;
// its first target ID also populates _node.MonitorID.
func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRollup, *sqlgraph.CreateSpec) {
	var (
		_node = &ChannelMonitorDailyRollup{config: _c.config}
		_spec = sqlgraph.NewCreateSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
	)
	_spec.OnConflict = _c.conflict
	if value, ok := _c.mutation.Model(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
		_node.Model = value
	}
	if value, ok := _c.mutation.BucketDate(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
		_node.BucketDate = value
	}
	if value, ok := _c.mutation.TotalChecks(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
		_node.TotalChecks = value
	}
	if value, ok := _c.mutation.OkCount(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
		_node.OkCount = value
	}
	if value, ok := _c.mutation.OperationalCount(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
		_node.OperationalCount = value
	}
	if value, ok := _c.mutation.DegradedCount(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
		_node.DegradedCount = value
	}
	if value, ok := _c.mutation.FailedCount(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
		_node.FailedCount = value
	}
	if value, ok := _c.mutation.ErrorCount(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
		_node.ErrorCount = value
	}
	if value, ok := _c.mutation.SumLatencyMs(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
		_node.SumLatencyMs = value
	}
	if value, ok := _c.mutation.CountLatency(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
		_node.CountLatency = value
	}
	if value, ok := _c.mutation.SumPingLatencyMs(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
		_node.SumPingLatencyMs = value
	}
	if value, ok := _c.mutation.CountPingLatency(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
		_node.CountPingLatency = value
	}
	if value, ok := _c.mutation.ComputedAt(); ok {
		_spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
		_node.ComputedAt = value
	}
	if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
		edge := &sqlgraph.EdgeSpec{
			Rel:     sqlgraph.M2O,
			Inverse: true,
			Table:   channelmonitordailyrollup.MonitorTable,
			Columns: []string{channelmonitordailyrollup.MonitorColumn},
			Bidi:    false,
			Target: &sqlgraph.EdgeTarget{
				IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
			},
		}
		for _, k := range nodes {
			edge.Target.Nodes = append(edge.Target.Nodes, k)
		}
		_node.MonitorID = nodes[0]
		_spec.Edges = append(_spec.Edges, edge)
	}
	return _node, _spec
}

// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
//	client.ChannelMonitorDailyRollup.Create().
//		SetMonitorID(v).
//		OnConflict(
//			// Update the row with the new values
//			// the was proposed for insertion.
//			sql.ResolveWithNewValues(),
//		).
//		// Override some of the fields with custom
//		// update values.
//		Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
//			SetMonitorID(v+v).
//		}).
//		Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertOne {
	_c.conflict = opts
	return &ChannelMonitorDailyRollupUpsertOne{
		create: _c,
	}
}

// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(sql.ConflictColumns(columns...)).
//	Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreate) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertOne {
	_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
	return &ChannelMonitorDailyRollupUpsertOne{
		create: _c,
	}
}

type (
	// ChannelMonitorDailyRollupUpsertOne is the builder for "upsert"-ing
	// one ChannelMonitorDailyRollup node.
	ChannelMonitorDailyRollupUpsertOne struct {
		create *ChannelMonitorDailyRollupCreate
	}

	// ChannelMonitorDailyRollupUpsert is the "OnConflict" setter.
	// It wraps sql.UpdateSet, so Set/Add write literal values into the
	// conflict UPDATE clause while SetExcluded references the value that
	// was proposed for insertion (SQL "excluded" pseudo-table).
	ChannelMonitorDailyRollupUpsert struct {
		*sql.UpdateSet
	}
)

// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsert) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldMonitorID, v)
	return u
}

// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateMonitorID() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldMonitorID)
	return u
}

// SetModel sets the "model" field.
func (u *ChannelMonitorDailyRollupUpsert) SetModel(v string) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldModel, v)
	return u
}

// UpdateModel sets the "model" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateModel() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldModel)
	return u
}

// SetBucketDate sets the "bucket_date" field.
func (u *ChannelMonitorDailyRollupUpsert) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldBucketDate, v)
	return u
}

// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateBucketDate() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldBucketDate)
	return u
}

// SetTotalChecks sets the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsert) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldTotalChecks, v)
	return u
}

// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldTotalChecks)
	return u
}

// AddTotalChecks adds v to the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsert) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldTotalChecks, v)
	return u
}

// SetOkCount sets the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsert) SetOkCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldOkCount, v)
	return u
}

// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateOkCount() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldOkCount)
	return u
}

// AddOkCount adds v to the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsert) AddOkCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldOkCount, v)
	return u
}

// SetOperationalCount sets the "operational_count" field.
func (u *ChannelMonitorDailyRollupUpsert) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldOperationalCount, v)
	return u
}

// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldOperationalCount)
	return u
}

// AddOperationalCount adds v to the "operational_count" field.
func (u *ChannelMonitorDailyRollupUpsert) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldOperationalCount, v)
	return u
}

// SetDegradedCount sets the "degraded_count" field.
func (u *ChannelMonitorDailyRollupUpsert) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldDegradedCount, v)
	return u
}

// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldDegradedCount)
	return u
}

// AddDegradedCount adds v to the "degraded_count" field.
func (u *ChannelMonitorDailyRollupUpsert) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldDegradedCount, v)
	return u
}

// SetFailedCount sets the "failed_count" field.
func (u *ChannelMonitorDailyRollupUpsert) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldFailedCount, v)
	return u
}

// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateFailedCount() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldFailedCount)
	return u
}

// AddFailedCount adds v to the "failed_count" field.
func (u *ChannelMonitorDailyRollupUpsert) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldFailedCount, v)
	return u
}

// SetErrorCount sets the "error_count" field.
func (u *ChannelMonitorDailyRollupUpsert) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldErrorCount, v)
	return u
}

// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateErrorCount() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldErrorCount)
	return u
}

// AddErrorCount adds v to the "error_count" field.
func (u *ChannelMonitorDailyRollupUpsert) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldErrorCount, v)
	return u
}

// SetSumLatencyMs sets the "sum_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsert) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldSumLatencyMs, v)
	return u
}

// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldSumLatencyMs)
	return u
}

// AddSumLatencyMs adds v to the "sum_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsert) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldSumLatencyMs, v)
	return u
}

// SetCountLatency sets the "count_latency" field.
func (u *ChannelMonitorDailyRollupUpsert) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldCountLatency, v)
	return u
}

// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateCountLatency() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldCountLatency)
	return u
}

// AddCountLatency adds v to the "count_latency" field.
func (u *ChannelMonitorDailyRollupUpsert) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldCountLatency, v)
	return u
}

// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsert) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
	return u
}

// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldSumPingLatencyMs)
	return u
}

// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsert) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
	return u
}

// SetCountPingLatency sets the "count_ping_latency" field.
func (u *ChannelMonitorDailyRollupUpsert) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldCountPingLatency, v)
	return u
}

// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldCountPingLatency)
	return u
}

// AddCountPingLatency adds v to the "count_ping_latency" field.
func (u *ChannelMonitorDailyRollupUpsert) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
	u.Add(channelmonitordailyrollup.FieldCountPingLatency, v)
	return u
}

// SetComputedAt sets the "computed_at" field.
func (u *ChannelMonitorDailyRollupUpsert) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
	u.Set(channelmonitordailyrollup.FieldComputedAt, v)
	return u
}

// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsert) UpdateComputedAt() *ChannelMonitorDailyRollupUpsert {
	u.SetExcluded(channelmonitordailyrollup.FieldComputedAt)
	return u
}

// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(
//			sql.ResolveWithNewValues(),
//		).
//		Exec(ctx)
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateNewValues() *ChannelMonitorDailyRollupUpsertOne {
	u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
	return u
}

// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(sql.ResolveWithIgnore()).
//		Exec(ctx)
func (u *ChannelMonitorDailyRollupUpsertOne) Ignore() *ChannelMonitorDailyRollupUpsertOne {
	u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
	return u
}

// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *ChannelMonitorDailyRollupUpsertOne) DoNothing() *ChannelMonitorDailyRollupUpsertOne {
	u.create.conflict = append(u.create.conflict, sql.DoNothing())
	return u
}

// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreate.OnConflict
// documentation for more info.
func (u *ChannelMonitorDailyRollupUpsertOne) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertOne {
	u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
		set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
	}))
	return u
}

// The setters below are thin wrappers that route each field operation
// through Update() into the conflict-resolution clause.

// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetMonitorID(v)
	})
}

// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateMonitorID()
	})
}

// SetModel sets the "model" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetModel(v string) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetModel(v)
	})
}

// UpdateModel sets the "model" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateModel() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateModel()
	})
}

// SetBucketDate sets the "bucket_date" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetBucketDate(v)
	})
}

// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateBucketDate()
	})
}

// SetTotalChecks sets the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetTotalChecks(v)
	})
}

// AddTotalChecks adds v to the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddTotalChecks(v)
	})
}

// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateTotalChecks()
	})
}

// SetOkCount sets the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetOkCount(v)
	})
}

// AddOkCount adds v to the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddOkCount(v)
	})
}

// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOkCount() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateOkCount()
	})
}

// SetOperationalCount sets the "operational_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetOperationalCount(v)
	})
}

// AddOperationalCount adds v to the "operational_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddOperationalCount(v)
	})
}

// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateOperationalCount()
	})
}

// SetDegradedCount sets the "degraded_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetDegradedCount(v)
	})
}

// AddDegradedCount adds v to the "degraded_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddDegradedCount(v)
	})
}

// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateDegradedCount()
	})
}

// SetFailedCount sets the "failed_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetFailedCount(v)
	})
}

// AddFailedCount adds v to the "failed_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddFailedCount(v)
	})
}

// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateFailedCount()
	})
}

// SetErrorCount sets the "error_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetErrorCount(v)
	})
}

// AddErrorCount adds v to the "error_count" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddErrorCount(v)
	})
}

// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateErrorCount()
	})
}

// SetSumLatencyMs sets the "sum_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetSumLatencyMs(v)
	})
}

// AddSumLatencyMs adds v to the "sum_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddSumLatencyMs(v)
	})
}

// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateSumLatencyMs()
	})
}

// SetCountLatency sets the "count_latency" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetCountLatency(v)
	})
}

// AddCountLatency adds v to the "count_latency" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddCountLatency(v)
	})
}

// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateCountLatency()
	})
}

// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetSumPingLatencyMs(v)
	})
}

// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddSumPingLatencyMs(v)
	})
}

// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateSumPingLatencyMs()
	})
}

// SetCountPingLatency sets the "count_ping_latency" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetCountPingLatency(v)
	})
}

// AddCountPingLatency adds v to the "count_ping_latency" field.
func (u *ChannelMonitorDailyRollupUpsertOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddCountPingLatency(v)
	})
}

// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateCountPingLatency()
	})
}

// SetComputedAt sets the "computed_at" field.
func (u *ChannelMonitorDailyRollupUpsertOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetComputedAt(v)
	})
}

// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertOne) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertOne {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateComputedAt()
	})
}

// Exec executes the query.
func (u *ChannelMonitorDailyRollupUpsertOne) Exec(ctx context.Context) error {
	if len(u.create.conflict) == 0 {
		return errors.New("ent: missing options for ChannelMonitorDailyRollupCreate.OnConflict")
	}
	return u.create.Exec(ctx)
}

// ExecX is like Exec, but panics if an error occurs.
// NOTE(review): calls u.create.Exec directly, so it bypasses the empty-conflict
// guard in Exec above — matches the upstream ent template.
func (u *ChannelMonitorDailyRollupUpsertOne) ExecX(ctx context.Context) {
	if err := u.create.Exec(ctx); err != nil {
		panic(err)
	}
}

// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *ChannelMonitorDailyRollupUpsertOne) ID(ctx context.Context) (id int64, err error) {
	node, err := u.create.Save(ctx)
	if err != nil {
		return id, err
	}
	return node.ID, nil
}

// IDX is like ID, but panics if an error occurs.
func (u *ChannelMonitorDailyRollupUpsertOne) IDX(ctx context.Context) int64 {
	id, err := u.ID(ctx)
	if err != nil {
		panic(err)
	}
	return id
}

// ChannelMonitorDailyRollupCreateBulk is the builder for creating many ChannelMonitorDailyRollup entities in bulk.
type ChannelMonitorDailyRollupCreateBulk struct {
	config
	err      error // deferred construction error, surfaced on Save
	builders []*ChannelMonitorDailyRollupCreate
	conflict []sql.ConflictOption
}

// Save creates the ChannelMonitorDailyRollup entities in the database.
// Each builder's hooks are chained so that mutator i triggers mutator i+1;
// the last mutator in the chain performs the single batched INSERT.
func (_c *ChannelMonitorDailyRollupCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
	if _c.err != nil {
		return nil, _c.err
	}
	specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
	nodes := make([]*ChannelMonitorDailyRollup, len(_c.builders))
	mutators := make([]Mutator, len(_c.builders))
	for i := range _c.builders {
		func(i int, root context.Context) {
			builder := _c.builders[i]
			builder.defaults()
			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
				mutation, ok := m.(*ChannelMonitorDailyRollupMutation)
				if !ok {
					return nil, fmt.Errorf("unexpected mutation type %T", m)
				}
				if err := builder.check(); err != nil {
					return nil, err
				}
				builder.mutation = mutation
				var err error
				nodes[i], specs[i] = builder.createSpec()
				if i < len(mutators)-1 {
					// Not the last builder: advance the chain with the root
					// context so earlier hooks cannot shadow it.
					_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
				} else {
					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
					spec.OnConflict = _c.conflict
					// Invoke the actual operation on the latest mutation in the chain.
					if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
						if sqlgraph.IsConstraintError(err) {
							err = &ConstraintError{msg: err.Error(), wrap: err}
						}
					}
				}
				if err != nil {
					return nil, err
				}
				mutation.id = &nodes[i].ID
				if specs[i].ID.Value != nil {
					id := specs[i].ID.Value.(int64)
					nodes[i].ID = int64(id)
				}
				mutation.done = true
				return nodes[i], nil
			})
			// Wrap with the builder's hooks, outermost hook last-registered.
			for i := len(builder.hooks) - 1; i >= 0; i-- {
				mut = builder.hooks[i](mut)
			}
			mutators[i] = mut
		}(i, ctx)
	}
	if len(mutators) > 0 {
		if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
			return nil, err
		}
	}
	return nodes, nil
}

// SaveX is like Save, but panics if an error occurs.
func (_c *ChannelMonitorDailyRollupCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorDailyRollup {
	v, err := _c.Save(ctx)
	if err != nil {
		panic(err)
	}
	return v
}

// Exec executes the query.
func (_c *ChannelMonitorDailyRollupCreateBulk) Exec(ctx context.Context) error {
	_, err := _c.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func (_c *ChannelMonitorDailyRollupCreateBulk) ExecX(ctx context.Context) {
	if err := _c.Exec(ctx); err != nil {
		panic(err)
	}
}

// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
//	client.ChannelMonitorDailyRollup.CreateBulk(builders...).
//		OnConflict(
//			// Update the row with the new values
//			// the was proposed for insertion.
//			sql.ResolveWithNewValues(),
//		).
//		// Override some of the fields with custom
//		// update values.
//		Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
//			SetMonitorID(v+v).
//		}).
//	Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertBulk {
	_c.conflict = opts
	return &ChannelMonitorDailyRollupUpsertBulk{
		create: _c,
	}
}

// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(sql.ConflictColumns(columns...)).
//		Exec(ctx)
func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertBulk {
	_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
	return &ChannelMonitorDailyRollupUpsertBulk{
		create: _c,
	}
}

// ChannelMonitorDailyRollupUpsertBulk is the builder for "upsert"-ing
// a bulk of ChannelMonitorDailyRollup nodes.
type ChannelMonitorDailyRollupUpsertBulk struct {
	create *ChannelMonitorDailyRollupCreateBulk
}

// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(
//			sql.ResolveWithNewValues(),
//		).
//		Exec(ctx)
func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateNewValues() *ChannelMonitorDailyRollupUpsertBulk {
	u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
	return u
}

// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
//	client.ChannelMonitorDailyRollup.Create().
//		OnConflict(sql.ResolveWithIgnore()).
//		Exec(ctx)
func (u *ChannelMonitorDailyRollupUpsertBulk) Ignore() *ChannelMonitorDailyRollupUpsertBulk {
	u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
	return u
}

// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *ChannelMonitorDailyRollupUpsertBulk) DoNothing() *ChannelMonitorDailyRollupUpsertBulk {
	u.create.conflict = append(u.create.conflict, sql.DoNothing())
	return u
}

// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreateBulk.OnConflict
// documentation for more info.
func (u *ChannelMonitorDailyRollupUpsertBulk) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertBulk {
	u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
		set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
	}))
	return u
}

// SetMonitorID sets the "monitor_id" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetMonitorID(v)
	})
}

// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateMonitorID()
	})
}

// SetModel sets the "model" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetModel(v string) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetModel(v)
	})
}

// UpdateModel sets the "model" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateModel() *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateModel()
	})
}

// SetBucketDate sets the "bucket_date" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetBucketDate(v)
	})
}

// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateBucketDate()
	})
}

// SetTotalChecks sets the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetTotalChecks(v)
	})
}

// AddTotalChecks adds v to the "total_checks" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddTotalChecks(v)
	})
}

// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.UpdateTotalChecks()
	})
}

// SetOkCount sets the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.SetOkCount(v)
	})
}

// AddOkCount adds v to the "ok_count" field.
func (u *ChannelMonitorDailyRollupUpsertBulk) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
	return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
		s.AddOkCount(v)
	})
}

// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOkCount() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateOkCount() + }) +} + +// SetOperationalCount sets the "operational_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetOperationalCount(v) + }) +} + +// AddOperationalCount adds v to the "operational_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddOperationalCount(v) + }) +} + +// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateOperationalCount() + }) +} + +// SetDegradedCount sets the "degraded_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetDegradedCount(v) + }) +} + +// AddDegradedCount adds v to the "degraded_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddDegradedCount(v) + }) +} + +// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateDegradedCount() + }) +} + +// SetFailedCount sets the "failed_count" field. 
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetFailedCount(v) + }) +} + +// AddFailedCount adds v to the "failed_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddFailedCount(v) + }) +} + +// UpdateFailedCount sets the "failed_count" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateFailedCount() + }) +} + +// SetErrorCount sets the "error_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetErrorCount(v) + }) +} + +// AddErrorCount adds v to the "error_count" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddErrorCount(v) + }) +} + +// UpdateErrorCount sets the "error_count" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateErrorCount() + }) +} + +// SetSumLatencyMs sets the "sum_latency_ms" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetSumLatencyMs(v) + }) +} + +// AddSumLatencyMs adds v to the "sum_latency_ms" field. 
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddSumLatencyMs(v) + }) +} + +// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateSumLatencyMs() + }) +} + +// SetCountLatency sets the "count_latency" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetCountLatency(v) + }) +} + +// AddCountLatency adds v to the "count_latency" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddCountLatency(v) + }) +} + +// UpdateCountLatency sets the "count_latency" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateCountLatency() + }) +} + +// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetSumPingLatencyMs(v) + }) +} + +// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddSumPingLatencyMs(v) + }) +} + +// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create. 
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateSumPingLatencyMs() + }) +} + +// SetCountPingLatency sets the "count_ping_latency" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetCountPingLatency(v) + }) +} + +// AddCountPingLatency adds v to the "count_ping_latency" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.AddCountPingLatency(v) + }) +} + +// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateCountPingLatency() + }) +} + +// SetComputedAt sets the "computed_at" field. +func (u *ChannelMonitorDailyRollupUpsertBulk) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.SetComputedAt(v) + }) +} + +// UpdateComputedAt sets the "computed_at" field to the value that was provided on create. +func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertBulk { + return u.Update(func(s *ChannelMonitorDailyRollupUpsert) { + s.UpdateComputedAt() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorDailyRollupUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the ChannelMonitorDailyRollupCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorDailyRollupCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorDailyRollupUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitordailyrollup_delete.go b/backend/ent/channelmonitordailyrollup_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..460c94f85cac19fb115baed42c10e8fc8f611b6b --- /dev/null +++ b/backend/ent/channelmonitordailyrollup_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorDailyRollupDelete is the builder for deleting a ChannelMonitorDailyRollup entity. +type ChannelMonitorDailyRollupDelete struct { + config + hooks []Hook + mutation *ChannelMonitorDailyRollupMutation +} + +// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder. +func (_d *ChannelMonitorDailyRollupDelete) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ChannelMonitorDailyRollupDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *ChannelMonitorDailyRollupDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ChannelMonitorDailyRollupDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ChannelMonitorDailyRollupDeleteOne is the builder for deleting a single ChannelMonitorDailyRollup entity. +type ChannelMonitorDailyRollupDeleteOne struct { + _d *ChannelMonitorDailyRollupDelete +} + +// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder. +func (_d *ChannelMonitorDailyRollupDeleteOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ChannelMonitorDailyRollupDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{channelmonitordailyrollup.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *ChannelMonitorDailyRollupDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitordailyrollup_query.go b/backend/ent/channelmonitordailyrollup_query.go new file mode 100644 index 0000000000000000000000000000000000000000..e34afc6163ae92811f5a3450d1a1fe57546640a2 --- /dev/null +++ b/backend/ent/channelmonitordailyrollup_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorDailyRollupQuery is the builder for querying ChannelMonitorDailyRollup entities. +type ChannelMonitorDailyRollupQuery struct { + config + ctx *QueryContext + order []channelmonitordailyrollup.OrderOption + inters []Interceptor + predicates []predicate.ChannelMonitorDailyRollup + withMonitor *ChannelMonitorQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ChannelMonitorDailyRollupQuery builder. +func (_q *ChannelMonitorDailyRollupQuery) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ChannelMonitorDailyRollupQuery) Limit(limit int) *ChannelMonitorDailyRollupQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. 
+func (_q *ChannelMonitorDailyRollupQuery) Offset(offset int) *ChannelMonitorDailyRollupQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ChannelMonitorDailyRollupQuery) Unique(unique bool) *ChannelMonitorDailyRollupQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ChannelMonitorDailyRollupQuery) Order(o ...channelmonitordailyrollup.OrderOption) *ChannelMonitorDailyRollupQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryMonitor chains the current query on the "monitor" edge. +func (_q *ChannelMonitorDailyRollupQuery) QueryMonitor() *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, selector), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first ChannelMonitorDailyRollup entity from the query. +// Returns a *NotFoundError when no ChannelMonitorDailyRollup was found. 
+func (_q *ChannelMonitorDailyRollupQuery) First(ctx context.Context) (*ChannelMonitorDailyRollup, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{channelmonitordailyrollup.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) FirstX(ctx context.Context) *ChannelMonitorDailyRollup { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ChannelMonitorDailyRollup ID from the query. +// Returns a *NotFoundError when no ChannelMonitorDailyRollup ID was found. +func (_q *ChannelMonitorDailyRollupQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{channelmonitordailyrollup.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ChannelMonitorDailyRollup entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup entity is found. +// Returns a *NotFoundError when no ChannelMonitorDailyRollup entities are found. 
+func (_q *ChannelMonitorDailyRollupQuery) Only(ctx context.Context) (*ChannelMonitorDailyRollup, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{channelmonitordailyrollup.Label} + default: + return nil, &NotSingularError{channelmonitordailyrollup.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) OnlyX(ctx context.Context) *ChannelMonitorDailyRollup { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ChannelMonitorDailyRollup ID in the query. +// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ChannelMonitorDailyRollupQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{channelmonitordailyrollup.Label} + default: + err = &NotSingularError{channelmonitordailyrollup.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ChannelMonitorDailyRollups. 
+func (_q *ChannelMonitorDailyRollupQuery) All(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ChannelMonitorDailyRollup, *ChannelMonitorDailyRollupQuery]() + return withInterceptors[[]*ChannelMonitorDailyRollup](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) AllX(ctx context.Context) []*ChannelMonitorDailyRollup { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ChannelMonitorDailyRollup IDs. +func (_q *ChannelMonitorDailyRollupQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(channelmonitordailyrollup.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ChannelMonitorDailyRollupQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorDailyRollupQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. 
+func (_q *ChannelMonitorDailyRollupQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ChannelMonitorDailyRollupQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ChannelMonitorDailyRollupQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ChannelMonitorDailyRollupQuery) Clone() *ChannelMonitorDailyRollupQuery { + if _q == nil { + return nil + } + return &ChannelMonitorDailyRollupQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]channelmonitordailyrollup.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ChannelMonitorDailyRollup{}, _q.predicates...), + withMonitor: _q.withMonitor.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithMonitor tells the query-builder to eager-load the nodes that are connected to +// the "monitor" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChannelMonitorDailyRollupQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorDailyRollupQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withMonitor = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. 
+// +// Example: +// +// var v []struct { +// MonitorID int64 `json:"monitor_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ChannelMonitorDailyRollup.Query(). +// GroupBy(channelmonitordailyrollup.FieldMonitorID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string) *ChannelMonitorDailyRollupGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ChannelMonitorDailyRollupGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = channelmonitordailyrollup.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// MonitorID int64 `json:"monitor_id,omitempty"` +// } +// +// client.ChannelMonitorDailyRollup.Query(). +// Select(channelmonitordailyrollup.FieldMonitorID). +// Scan(ctx, &v) +func (_q *ChannelMonitorDailyRollupQuery) Select(fields ...string) *ChannelMonitorDailyRollupSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ChannelMonitorDailyRollupSelect{ChannelMonitorDailyRollupQuery: _q} + sbuild.label = channelmonitordailyrollup.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ChannelMonitorDailyRollupSelect configured with the given aggregations. +func (_q *ChannelMonitorDailyRollupQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *ChannelMonitorDailyRollupQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !channelmonitordailyrollup.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ChannelMonitorDailyRollupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorDailyRollup, error) { + var ( + nodes = []*ChannelMonitorDailyRollup{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withMonitor != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ChannelMonitorDailyRollup).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ChannelMonitorDailyRollup{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withMonitor; query != nil { + if err := _q.loadMonitor(ctx, query, nodes, nil, + func(n *ChannelMonitorDailyRollup, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *ChannelMonitorDailyRollupQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorDailyRollup, init func(*ChannelMonitorDailyRollup), assign func(*ChannelMonitorDailyRollup, 
*ChannelMonitor)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*ChannelMonitorDailyRollup) + for i := range nodes { + fk := nodes[i].MonitorID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(channelmonitor.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *ChannelMonitorDailyRollupQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ChannelMonitorDailyRollupQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID) + for i := range fields { + if fields[i] != channelmonitordailyrollup.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withMonitor != nil { + _spec.Node.AddColumnOnce(channelmonitordailyrollup.FieldMonitorID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) 
+ } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ChannelMonitorDailyRollupQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(channelmonitordailyrollup.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = channelmonitordailyrollup.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ChannelMonitorDailyRollupQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. 
Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ChannelMonitorDailyRollupQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ChannelMonitorDailyRollupGroupBy is the group-by builder for ChannelMonitorDailyRollup entities. +type ChannelMonitorDailyRollupGroupBy struct { + selector + build *ChannelMonitorDailyRollupQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ChannelMonitorDailyRollupGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *ChannelMonitorDailyRollupGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ChannelMonitorDailyRollupGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) 
+ if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ChannelMonitorDailyRollupSelect is the builder for selecting fields of ChannelMonitorDailyRollup entities. +type ChannelMonitorDailyRollupSelect struct { + *ChannelMonitorDailyRollupQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ChannelMonitorDailyRollupSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *ChannelMonitorDailyRollupSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupSelect](ctx, _s.ChannelMonitorDailyRollupQuery, _s, _s.inters, v) +} + +func (_s *ChannelMonitorDailyRollupSelect) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/channelmonitordailyrollup_update.go b/backend/ent/channelmonitordailyrollup_update.go new file mode 100644 index 0000000000000000000000000000000000000000..02cd86c5c7c4f3f76d8deb0d0b73550e931229a8 --- /dev/null +++ b/backend/ent/channelmonitordailyrollup_update.go @@ -0,0 +1,961 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorDailyRollupUpdate is the builder for updating ChannelMonitorDailyRollup entities. +type ChannelMonitorDailyRollupUpdate struct { + config + hooks []Hook + mutation *ChannelMonitorDailyRollupMutation +} + +// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder. +func (_u *ChannelMonitorDailyRollupUpdate) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetMonitorID sets the "monitor_id" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdate { + _u.mutation.SetMonitorID(v) + return _u +} + +// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetMonitorID(*v) + } + return _u +} + +// SetModel sets the "model" field. 
+func (_u *ChannelMonitorDailyRollupUpdate) SetModel(v string) *ChannelMonitorDailyRollupUpdate { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetBucketDate sets the "bucket_date" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdate { + _u.mutation.SetBucketDate(v) + return _u +} + +// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetBucketDate(*v) + } + return _u +} + +// SetTotalChecks sets the "total_checks" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetTotalChecks() + _u.mutation.SetTotalChecks(v) + return _u +} + +// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetTotalChecks(*v) + } + return _u +} + +// AddTotalChecks adds value to the "total_checks" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddTotalChecks(v) + return _u +} + +// SetOkCount sets the "ok_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetOkCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetOkCount() + _u.mutation.SetOkCount(v) + return _u +} + +// SetNillableOkCount sets the "ok_count" field if the given value is not nil. 
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetOkCount(*v) + } + return _u +} + +// AddOkCount adds value to the "ok_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddOkCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddOkCount(v) + return _u +} + +// SetOperationalCount sets the "operational_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetOperationalCount() + _u.mutation.SetOperationalCount(v) + return _u +} + +// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetOperationalCount(*v) + } + return _u +} + +// AddOperationalCount adds value to the "operational_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddOperationalCount(v) + return _u +} + +// SetDegradedCount sets the "degraded_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetDegradedCount() + _u.mutation.SetDegradedCount(v) + return _u +} + +// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetDegradedCount(*v) + } + return _u +} + +// AddDegradedCount adds value to the "degraded_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddDegradedCount(v) + return _u +} + +// SetFailedCount sets the "failed_count" field. 
+func (_u *ChannelMonitorDailyRollupUpdate) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetFailedCount() + _u.mutation.SetFailedCount(v) + return _u +} + +// SetNillableFailedCount sets the "failed_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetFailedCount(*v) + } + return _u +} + +// AddFailedCount adds value to the "failed_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddFailedCount(v) + return _u +} + +// SetErrorCount sets the "error_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetErrorCount() + _u.mutation.SetErrorCount(v) + return _u +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetErrorCount(*v) + } + return _u +} + +// AddErrorCount adds value to the "error_count" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddErrorCount(v) + return _u +} + +// SetSumLatencyMs sets the "sum_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetSumLatencyMs() + _u.mutation.SetSumLatencyMs(v) + return _u +} + +// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetSumLatencyMs(*v) + } + return _u +} + +// AddSumLatencyMs adds value to the "sum_latency_ms" field. 
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddSumLatencyMs(v) + return _u +} + +// SetCountLatency sets the "count_latency" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetCountLatency() + _u.mutation.SetCountLatency(v) + return _u +} + +// SetNillableCountLatency sets the "count_latency" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetCountLatency(*v) + } + return _u +} + +// AddCountLatency adds value to the "count_latency" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddCountLatency(v) + return _u +} + +// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetSumPingLatencyMs() + _u.mutation.SetSumPingLatencyMs(v) + return _u +} + +// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetSumPingLatencyMs(*v) + } + return _u +} + +// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddSumPingLatencyMs(v) + return _u +} + +// SetCountPingLatency sets the "count_ping_latency" field. 
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.ResetCountPingLatency() + _u.mutation.SetCountPingLatency(v) + return _u +} + +// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdate { + if v != nil { + _u.SetCountPingLatency(*v) + } + return _u +} + +// AddCountPingLatency adds value to the "count_ping_latency" field. +func (_u *ChannelMonitorDailyRollupUpdate) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate { + _u.mutation.AddCountPingLatency(v) + return _u +} + +// SetComputedAt sets the "computed_at" field. +func (_u *ChannelMonitorDailyRollupUpdate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdate { + _u.mutation.SetComputedAt(v) + return _u +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorDailyRollupUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdate { + return _u.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder. +func (_u *ChannelMonitorDailyRollupUpdate) Mutation() *ChannelMonitorDailyRollupMutation { + return _u.mutation +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorDailyRollupUpdate) ClearMonitor() *ChannelMonitorDailyRollupUpdate { + _u.mutation.ClearMonitor() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ChannelMonitorDailyRollupUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_u *ChannelMonitorDailyRollupUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ChannelMonitorDailyRollupUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorDailyRollupUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ChannelMonitorDailyRollupUpdate) defaults() { + if _, ok := _u.mutation.ComputedAt(); !ok { + v := channelmonitordailyrollup.UpdateDefaultComputedAt() + _u.mutation.SetComputedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorDailyRollupUpdate) check() error { + if v, ok := _u.mutation.Model(); ok { + if err := channelmonitordailyrollup.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)} + } + } + if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`) + } + return nil +} + +func (_u *ChannelMonitorDailyRollupUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value) + } + if value, ok := 
_u.mutation.BucketDate(); ok { + _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value) + } + if value, ok := _u.mutation.TotalChecks(); ok { + _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTotalChecks(); ok { + _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value) + } + if value, ok := _u.mutation.OkCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOkCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value) + } + if value, ok := _u.mutation.OperationalCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOperationalCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.DegradedCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDegradedCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.FailedCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedFailedCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedErrorCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SumLatencyMs(); ok { + _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSumLatencyMs(); ok { + 
_spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.CountLatency(); ok { + _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCountLatency(); ok { + _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.SumPingLatencyMs(); ok { + _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok { + _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.CountPingLatency(); ok { + _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCountPingLatency(); ok { + _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.ComputedAt(); ok { + _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value) + } + if _u.mutation.MonitorCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitordailyrollup.MonitorTable, + Columns: []string{channelmonitordailyrollup.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitordailyrollup.MonitorTable, + Columns: []string{channelmonitordailyrollup.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = 
append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitordailyrollup.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ChannelMonitorDailyRollupUpdateOne is the builder for updating a single ChannelMonitorDailyRollup entity. +type ChannelMonitorDailyRollupUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ChannelMonitorDailyRollupMutation +} + +// SetMonitorID sets the "monitor_id" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.SetMonitorID(v) + return _u +} + +// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetMonitorID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetModel(v string) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetBucketDate sets the "bucket_date" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.SetBucketDate(v) + return _u +} + +// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetBucketDate(*v) + } + return _u +} + +// SetTotalChecks sets the "total_checks" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetTotalChecks() + _u.mutation.SetTotalChecks(v) + return _u +} + +// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetTotalChecks(*v) + } + return _u +} + +// AddTotalChecks adds value to the "total_checks" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddTotalChecks(v) + return _u +} + +// SetOkCount sets the "ok_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetOkCount() + _u.mutation.SetOkCount(v) + return _u +} + +// SetNillableOkCount sets the "ok_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetOkCount(*v) + } + return _u +} + +// AddOkCount adds value to the "ok_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddOkCount(v) + return _u +} + +// SetOperationalCount sets the "operational_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetOperationalCount() + _u.mutation.SetOperationalCount(v) + return _u +} + +// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetOperationalCount(*v) + } + return _u +} + +// AddOperationalCount adds value to the "operational_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddOperationalCount(v) + return _u +} + +// SetDegradedCount sets the "degraded_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetDegradedCount() + _u.mutation.SetDegradedCount(v) + return _u +} + +// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetDegradedCount(*v) + } + return _u +} + +// AddDegradedCount adds value to the "degraded_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddDegradedCount(v) + return _u +} + +// SetFailedCount sets the "failed_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetFailedCount() + _u.mutation.SetFailedCount(v) + return _u +} + +// SetNillableFailedCount sets the "failed_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetFailedCount(*v) + } + return _u +} + +// AddFailedCount adds value to the "failed_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddFailedCount(v) + return _u +} + +// SetErrorCount sets the "error_count" field. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetErrorCount() + _u.mutation.SetErrorCount(v) + return _u +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetErrorCount(*v) + } + return _u +} + +// AddErrorCount adds value to the "error_count" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddErrorCount(v) + return _u +} + +// SetSumLatencyMs sets the "sum_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetSumLatencyMs() + _u.mutation.SetSumLatencyMs(v) + return _u +} + +// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetSumLatencyMs(*v) + } + return _u +} + +// AddSumLatencyMs adds value to the "sum_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddSumLatencyMs(v) + return _u +} + +// SetCountLatency sets the "count_latency" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetCountLatency() + _u.mutation.SetCountLatency(v) + return _u +} + +// SetNillableCountLatency sets the "count_latency" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetCountLatency(*v) + } + return _u +} + +// AddCountLatency adds value to the "count_latency" field. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddCountLatency(v) + return _u +} + +// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetSumPingLatencyMs() + _u.mutation.SetSumPingLatencyMs(v) + return _u +} + +// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetSumPingLatencyMs(*v) + } + return _u +} + +// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddSumPingLatencyMs(v) + return _u +} + +// SetCountPingLatency sets the "count_ping_latency" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ResetCountPingLatency() + _u.mutation.SetCountPingLatency(v) + return _u +} + +// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdateOne { + if v != nil { + _u.SetCountPingLatency(*v) + } + return _u +} + +// AddCountPingLatency adds value to the "count_ping_latency" field. +func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.AddCountPingLatency(v) + return _u +} + +// SetComputedAt sets the "computed_at" field. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.SetComputedAt(v) + return _u +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdateOne { + return _u.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder. +func (_u *ChannelMonitorDailyRollupUpdateOne) Mutation() *ChannelMonitorDailyRollupMutation { + return _u.mutation +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorDailyRollupUpdateOne) ClearMonitor() *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.ClearMonitor() + return _u +} + +// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder. +func (_u *ChannelMonitorDailyRollupUpdateOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ChannelMonitorDailyRollupUpdateOne) Select(field string, fields ...string) *ChannelMonitorDailyRollupUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ChannelMonitorDailyRollup entity. +func (_u *ChannelMonitorDailyRollupUpdateOne) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorDailyRollupUpdateOne) SaveX(ctx context.Context) *ChannelMonitorDailyRollup { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. 
+func (_u *ChannelMonitorDailyRollupUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorDailyRollupUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() { + if _, ok := _u.mutation.ComputedAt(); !ok { + v := channelmonitordailyrollup.UpdateDefaultComputedAt() + _u.mutation.SetComputedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorDailyRollupUpdateOne) check() error { + if v, ok := _u.mutation.Model(); ok { + if err := channelmonitordailyrollup.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)} + } + } + if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`) + } + return nil +} + +func (_u *ChannelMonitorDailyRollupUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorDailyRollup, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorDailyRollup.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID) + for _, f := range fields { + if !channelmonitordailyrollup.ValidColumn(f) { + return nil, 
&ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != channelmonitordailyrollup.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.BucketDate(); ok { + _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value) + } + if value, ok := _u.mutation.TotalChecks(); ok { + _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTotalChecks(); ok { + _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value) + } + if value, ok := _u.mutation.OkCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOkCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value) + } + if value, ok := _u.mutation.OperationalCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOperationalCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.DegradedCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDegradedCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.FailedCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedFailedCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldFailedCount, 
field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCount(); ok { + _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedErrorCount(); ok { + _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SumLatencyMs(); ok { + _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSumLatencyMs(); ok { + _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.CountLatency(); ok { + _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCountLatency(); ok { + _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.SumPingLatencyMs(); ok { + _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok { + _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value) + } + if value, ok := _u.mutation.CountPingLatency(); ok { + _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCountPingLatency(); ok { + _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value) + } + if value, ok := _u.mutation.ComputedAt(); ok { + _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value) + } + if _u.mutation.MonitorCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitordailyrollup.MonitorTable, + Columns: []string{channelmonitordailyrollup.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = 
append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitordailyrollup.MonitorTable, + Columns: []string{channelmonitordailyrollup.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &ChannelMonitorDailyRollup{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitordailyrollup.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/channelmonitorhistory.go b/backend/ent/channelmonitorhistory.go new file mode 100644 index 0000000000000000000000000000000000000000..70dde5422a7d2777d68f39b0cab797eebe40447e --- /dev/null +++ b/backend/ent/channelmonitorhistory.go @@ -0,0 +1,207 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" +) + +// ChannelMonitorHistory is the model entity for the ChannelMonitorHistory schema. +type ChannelMonitorHistory struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // MonitorID holds the value of the "monitor_id" field. + MonitorID int64 `json:"monitor_id,omitempty"` + // Model holds the value of the "model" field. + Model string `json:"model,omitempty"` + // Status holds the value of the "status" field. 
+ Status channelmonitorhistory.Status `json:"status,omitempty"` + // LatencyMs holds the value of the "latency_ms" field. + LatencyMs *int `json:"latency_ms,omitempty"` + // PingLatencyMs holds the value of the "ping_latency_ms" field. + PingLatencyMs *int `json:"ping_latency_ms,omitempty"` + // Message holds the value of the "message" field. + Message string `json:"message,omitempty"` + // CheckedAt holds the value of the "checked_at" field. + CheckedAt time.Time `json:"checked_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ChannelMonitorHistoryQuery when eager-loading is set. + Edges ChannelMonitorHistoryEdges `json:"edges"` + selectValues sql.SelectValues +} + +// ChannelMonitorHistoryEdges holds the relations/edges for other nodes in the graph. +type ChannelMonitorHistoryEdges struct { + // Monitor holds the value of the monitor edge. + Monitor *ChannelMonitor `json:"monitor,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// MonitorOrErr returns the Monitor value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e ChannelMonitorHistoryEdges) MonitorOrErr() (*ChannelMonitor, error) { + if e.Monitor != nil { + return e.Monitor, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: channelmonitor.Label} + } + return nil, &NotLoadedError{edge: "monitor"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*ChannelMonitorHistory) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case channelmonitorhistory.FieldID, channelmonitorhistory.FieldMonitorID, channelmonitorhistory.FieldLatencyMs, channelmonitorhistory.FieldPingLatencyMs: + values[i] = new(sql.NullInt64) + case channelmonitorhistory.FieldModel, channelmonitorhistory.FieldStatus, channelmonitorhistory.FieldMessage: + values[i] = new(sql.NullString) + case channelmonitorhistory.FieldCheckedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ChannelMonitorHistory fields. +func (_m *ChannelMonitorHistory) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case channelmonitorhistory.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case channelmonitorhistory.FieldMonitorID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field monitor_id", values[i]) + } else if value.Valid { + _m.MonitorID = value.Int64 + } + case channelmonitorhistory.FieldModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[i]) + } else if value.Valid { + _m.Model = value.String + } + case channelmonitorhistory.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = channelmonitorhistory.Status(value.String) + } + case channelmonitorhistory.FieldLatencyMs: + if value, ok := 
values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field latency_ms", values[i]) + } else if value.Valid { + _m.LatencyMs = new(int) + *_m.LatencyMs = int(value.Int64) + } + case channelmonitorhistory.FieldPingLatencyMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field ping_latency_ms", values[i]) + } else if value.Valid { + _m.PingLatencyMs = new(int) + *_m.PingLatencyMs = int(value.Int64) + } + case channelmonitorhistory.FieldMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field message", values[i]) + } else if value.Valid { + _m.Message = value.String + } + case channelmonitorhistory.FieldCheckedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field checked_at", values[i]) + } else if value.Valid { + _m.CheckedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorHistory. +// This includes values selected through modifiers, order, etc. +func (_m *ChannelMonitorHistory) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryMonitor queries the "monitor" edge of the ChannelMonitorHistory entity. +func (_m *ChannelMonitorHistory) QueryMonitor() *ChannelMonitorQuery { + return NewChannelMonitorHistoryClient(_m.config).QueryMonitor(_m) +} + +// Update returns a builder for updating this ChannelMonitorHistory. +// Note that you need to call ChannelMonitorHistory.Unwrap() before calling this method if this ChannelMonitorHistory +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *ChannelMonitorHistory) Update() *ChannelMonitorHistoryUpdateOne { + return NewChannelMonitorHistoryClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ChannelMonitorHistory entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ChannelMonitorHistory) Unwrap() *ChannelMonitorHistory { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ChannelMonitorHistory is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ChannelMonitorHistory) String() string { + var builder strings.Builder + builder.WriteString("ChannelMonitorHistory(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("monitor_id=") + builder.WriteString(fmt.Sprintf("%v", _m.MonitorID)) + builder.WriteString(", ") + builder.WriteString("model=") + builder.WriteString(_m.Model) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", _m.Status)) + builder.WriteString(", ") + if v := _m.LatencyMs; v != nil { + builder.WriteString("latency_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.PingLatencyMs; v != nil { + builder.WriteString("ping_latency_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("message=") + builder.WriteString(_m.Message) + builder.WriteString(", ") + builder.WriteString("checked_at=") + builder.WriteString(_m.CheckedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// ChannelMonitorHistories is a parsable slice of ChannelMonitorHistory. 
+type ChannelMonitorHistories []*ChannelMonitorHistory diff --git a/backend/ent/channelmonitorhistory/channelmonitorhistory.go b/backend/ent/channelmonitorhistory/channelmonitorhistory.go new file mode 100644 index 0000000000000000000000000000000000000000..6a9dc006703bef0adee45e61783f81115b76cd5f --- /dev/null +++ b/backend/ent/channelmonitorhistory/channelmonitorhistory.go @@ -0,0 +1,158 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitorhistory + +import ( + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the channelmonitorhistory type in the database. + Label = "channel_monitor_history" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldMonitorID holds the string denoting the monitor_id field in the database. + FieldMonitorID = "monitor_id" + // FieldModel holds the string denoting the model field in the database. + FieldModel = "model" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldLatencyMs holds the string denoting the latency_ms field in the database. + FieldLatencyMs = "latency_ms" + // FieldPingLatencyMs holds the string denoting the ping_latency_ms field in the database. + FieldPingLatencyMs = "ping_latency_ms" + // FieldMessage holds the string denoting the message field in the database. + FieldMessage = "message" + // FieldCheckedAt holds the string denoting the checked_at field in the database. + FieldCheckedAt = "checked_at" + // EdgeMonitor holds the string denoting the monitor edge name in mutations. + EdgeMonitor = "monitor" + // Table holds the table name of the channelmonitorhistory in the database. + Table = "channel_monitor_histories" + // MonitorTable is the table that holds the monitor relation/edge. + MonitorTable = "channel_monitor_histories" + // MonitorInverseTable is the table name for the ChannelMonitor entity. 
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package. + MonitorInverseTable = "channel_monitors" + // MonitorColumn is the table column denoting the monitor relation/edge. + MonitorColumn = "monitor_id" +) + +// Columns holds all SQL columns for channelmonitorhistory fields. +var Columns = []string{ + FieldID, + FieldMonitorID, + FieldModel, + FieldStatus, + FieldLatencyMs, + FieldPingLatencyMs, + FieldMessage, + FieldCheckedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // ModelValidator is a validator for the "model" field. It is called by the builders before save. + ModelValidator func(string) error + // DefaultMessage holds the default value on creation for the "message" field. + DefaultMessage string + // MessageValidator is a validator for the "message" field. It is called by the builders before save. + MessageValidator func(string) error + // DefaultCheckedAt holds the default value on creation for the "checked_at" field. + DefaultCheckedAt func() time.Time +) + +// Status defines the type for the "status" enum field. +type Status string + +// Status values. +const ( + StatusOperational Status = "operational" + StatusDegraded Status = "degraded" + StatusFailed Status = "failed" + StatusError Status = "error" +) + +func (s Status) String() string { + return string(s) +} + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s Status) error { + switch s { + case StatusOperational, StatusDegraded, StatusFailed, StatusError: + return nil + default: + return fmt.Errorf("channelmonitorhistory: invalid enum value for status field: %q", s) + } +} + +// OrderOption defines the ordering options for the ChannelMonitorHistory queries. 
+type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByMonitorID orders the results by the monitor_id field. +func ByMonitorID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonitorID, opts...).ToFunc() +} + +// ByModel orders the results by the model field. +func ByModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModel, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByLatencyMs orders the results by the latency_ms field. +func ByLatencyMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLatencyMs, opts...).ToFunc() +} + +// ByPingLatencyMs orders the results by the ping_latency_ms field. +func ByPingLatencyMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPingLatencyMs, opts...).ToFunc() +} + +// ByMessage orders the results by the message field. +func ByMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMessage, opts...).ToFunc() +} + +// ByCheckedAt orders the results by the checked_at field. +func ByCheckedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCheckedAt, opts...).ToFunc() +} + +// ByMonitorField orders the results by monitor field. 
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...)) + } +} +func newMonitorStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MonitorInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn), + ) +} diff --git a/backend/ent/channelmonitorhistory/where.go b/backend/ent/channelmonitorhistory/where.go new file mode 100644 index 0000000000000000000000000000000000000000..afa73f35c87594111eb4779c69356f19ebcbfa76 --- /dev/null +++ b/backend/ent/channelmonitorhistory/where.go @@ -0,0 +1,444 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitorhistory + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. 
+func IDGT(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldID, id)) +} + +// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ. +func MonitorID(v int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v)) +} + +// Model applies equality check predicate on the "model" field. It's identical to ModelEQ. +func Model(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v)) +} + +// LatencyMs applies equality check predicate on the "latency_ms" field. It's identical to LatencyMsEQ. +func LatencyMs(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v)) +} + +// PingLatencyMs applies equality check predicate on the "ping_latency_ms" field. It's identical to PingLatencyMsEQ. +func PingLatencyMs(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v)) +} + +// Message applies equality check predicate on the "message" field. It's identical to MessageEQ. +func Message(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v)) +} + +// CheckedAt applies equality check predicate on the "checked_at" field. It's identical to CheckedAtEQ. 
+func CheckedAt(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v)) +} + +// MonitorIDEQ applies the EQ predicate on the "monitor_id" field. +func MonitorIDEQ(v int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v)) +} + +// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field. +func MonitorIDNEQ(v int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMonitorID, v)) +} + +// MonitorIDIn applies the In predicate on the "monitor_id" field. +func MonitorIDIn(vs ...int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMonitorID, vs...)) +} + +// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field. +func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMonitorID, vs...)) +} + +// ModelEQ applies the EQ predicate on the "model" field. +func ModelEQ(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v)) +} + +// ModelNEQ applies the NEQ predicate on the "model" field. +func ModelNEQ(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldModel, v)) +} + +// ModelIn applies the In predicate on the "model" field. +func ModelIn(vs ...string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldModel, vs...)) +} + +// ModelNotIn applies the NotIn predicate on the "model" field. +func ModelNotIn(vs ...string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldModel, vs...)) +} + +// ModelGT applies the GT predicate on the "model" field. 
+func ModelGT(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldModel, v)) +} + +// ModelGTE applies the GTE predicate on the "model" field. +func ModelGTE(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldModel, v)) +} + +// ModelLT applies the LT predicate on the "model" field. +func ModelLT(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldModel, v)) +} + +// ModelLTE applies the LTE predicate on the "model" field. +func ModelLTE(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldModel, v)) +} + +// ModelContains applies the Contains predicate on the "model" field. +func ModelContains(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldContains(FieldModel, v)) +} + +// ModelHasPrefix applies the HasPrefix predicate on the "model" field. +func ModelHasPrefix(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldModel, v)) +} + +// ModelHasSuffix applies the HasSuffix predicate on the "model" field. +func ModelHasSuffix(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldModel, v)) +} + +// ModelEqualFold applies the EqualFold predicate on the "model" field. +func ModelEqualFold(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldModel, v)) +} + +// ModelContainsFold applies the ContainsFold predicate on the "model" field. +func ModelContainsFold(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldModel, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. 
+func StatusEQ(v Status) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v Status) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...Status) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...Status) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldStatus, vs...)) +} + +// LatencyMsEQ applies the EQ predicate on the "latency_ms" field. +func LatencyMsEQ(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v)) +} + +// LatencyMsNEQ applies the NEQ predicate on the "latency_ms" field. +func LatencyMsNEQ(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldLatencyMs, v)) +} + +// LatencyMsIn applies the In predicate on the "latency_ms" field. +func LatencyMsIn(vs ...int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldLatencyMs, vs...)) +} + +// LatencyMsNotIn applies the NotIn predicate on the "latency_ms" field. +func LatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldLatencyMs, vs...)) +} + +// LatencyMsGT applies the GT predicate on the "latency_ms" field. +func LatencyMsGT(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldLatencyMs, v)) +} + +// LatencyMsGTE applies the GTE predicate on the "latency_ms" field. 
+func LatencyMsGTE(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldLatencyMs, v)) +} + +// LatencyMsLT applies the LT predicate on the "latency_ms" field. +func LatencyMsLT(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldLatencyMs, v)) +} + +// LatencyMsLTE applies the LTE predicate on the "latency_ms" field. +func LatencyMsLTE(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldLatencyMs, v)) +} + +// LatencyMsIsNil applies the IsNil predicate on the "latency_ms" field. +func LatencyMsIsNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldLatencyMs)) +} + +// LatencyMsNotNil applies the NotNil predicate on the "latency_ms" field. +func LatencyMsNotNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldLatencyMs)) +} + +// PingLatencyMsEQ applies the EQ predicate on the "ping_latency_ms" field. +func PingLatencyMsEQ(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v)) +} + +// PingLatencyMsNEQ applies the NEQ predicate on the "ping_latency_ms" field. +func PingLatencyMsNEQ(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldPingLatencyMs, v)) +} + +// PingLatencyMsIn applies the In predicate on the "ping_latency_ms" field. +func PingLatencyMsIn(vs ...int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldPingLatencyMs, vs...)) +} + +// PingLatencyMsNotIn applies the NotIn predicate on the "ping_latency_ms" field. +func PingLatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldPingLatencyMs, vs...)) +} + +// PingLatencyMsGT applies the GT predicate on the "ping_latency_ms" field. 
+func PingLatencyMsGT(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldPingLatencyMs, v)) +} + +// PingLatencyMsGTE applies the GTE predicate on the "ping_latency_ms" field. +func PingLatencyMsGTE(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldPingLatencyMs, v)) +} + +// PingLatencyMsLT applies the LT predicate on the "ping_latency_ms" field. +func PingLatencyMsLT(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldPingLatencyMs, v)) +} + +// PingLatencyMsLTE applies the LTE predicate on the "ping_latency_ms" field. +func PingLatencyMsLTE(v int) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldPingLatencyMs, v)) +} + +// PingLatencyMsIsNil applies the IsNil predicate on the "ping_latency_ms" field. +func PingLatencyMsIsNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldPingLatencyMs)) +} + +// PingLatencyMsNotNil applies the NotNil predicate on the "ping_latency_ms" field. +func PingLatencyMsNotNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldPingLatencyMs)) +} + +// MessageEQ applies the EQ predicate on the "message" field. +func MessageEQ(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v)) +} + +// MessageNEQ applies the NEQ predicate on the "message" field. +func MessageNEQ(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMessage, v)) +} + +// MessageIn applies the In predicate on the "message" field. +func MessageIn(vs ...string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMessage, vs...)) +} + +// MessageNotIn applies the NotIn predicate on the "message" field. 
+func MessageNotIn(vs ...string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMessage, vs...)) +} + +// MessageGT applies the GT predicate on the "message" field. +func MessageGT(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldMessage, v)) +} + +// MessageGTE applies the GTE predicate on the "message" field. +func MessageGTE(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldMessage, v)) +} + +// MessageLT applies the LT predicate on the "message" field. +func MessageLT(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldMessage, v)) +} + +// MessageLTE applies the LTE predicate on the "message" field. +func MessageLTE(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldMessage, v)) +} + +// MessageContains applies the Contains predicate on the "message" field. +func MessageContains(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldContains(FieldMessage, v)) +} + +// MessageHasPrefix applies the HasPrefix predicate on the "message" field. +func MessageHasPrefix(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldMessage, v)) +} + +// MessageHasSuffix applies the HasSuffix predicate on the "message" field. +func MessageHasSuffix(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldMessage, v)) +} + +// MessageIsNil applies the IsNil predicate on the "message" field. +func MessageIsNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldMessage)) +} + +// MessageNotNil applies the NotNil predicate on the "message" field. 
+func MessageNotNil() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldMessage)) +} + +// MessageEqualFold applies the EqualFold predicate on the "message" field. +func MessageEqualFold(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldMessage, v)) +} + +// MessageContainsFold applies the ContainsFold predicate on the "message" field. +func MessageContainsFold(v string) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldMessage, v)) +} + +// CheckedAtEQ applies the EQ predicate on the "checked_at" field. +func CheckedAtEQ(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v)) +} + +// CheckedAtNEQ applies the NEQ predicate on the "checked_at" field. +func CheckedAtNEQ(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldCheckedAt, v)) +} + +// CheckedAtIn applies the In predicate on the "checked_at" field. +func CheckedAtIn(vs ...time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldIn(FieldCheckedAt, vs...)) +} + +// CheckedAtNotIn applies the NotIn predicate on the "checked_at" field. +func CheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldCheckedAt, vs...)) +} + +// CheckedAtGT applies the GT predicate on the "checked_at" field. +func CheckedAtGT(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGT(FieldCheckedAt, v)) +} + +// CheckedAtGTE applies the GTE predicate on the "checked_at" field. +func CheckedAtGTE(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldCheckedAt, v)) +} + +// CheckedAtLT applies the LT predicate on the "checked_at" field. 
+func CheckedAtLT(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLT(FieldCheckedAt, v)) +} + +// CheckedAtLTE applies the LTE predicate on the "checked_at" field. +func CheckedAtLTE(v time.Time) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldCheckedAt, v)) +} + +// HasMonitor applies the HasEdge predicate on the "monitor" edge. +func HasMonitor() predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates). +func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(func(s *sql.Selector) { + step := newMonitorStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory { + return predicate.ChannelMonitorHistory(sql.NotPredicates(p)) +} diff --git a/backend/ent/channelmonitorhistory_create.go b/backend/ent/channelmonitorhistory_create.go new file mode 100644 index 0000000000000000000000000000000000000000..71034865c97017a5babcca6fe84815cfa07841f0 --- /dev/null +++ b/backend/ent/channelmonitorhistory_create.go @@ -0,0 +1,947 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" +) + +// ChannelMonitorHistoryCreate is the builder for creating a ChannelMonitorHistory entity. +type ChannelMonitorHistoryCreate struct { + config + mutation *ChannelMonitorHistoryMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetMonitorID sets the "monitor_id" field. +func (_c *ChannelMonitorHistoryCreate) SetMonitorID(v int64) *ChannelMonitorHistoryCreate { + _c.mutation.SetMonitorID(v) + return _c +} + +// SetModel sets the "model" field. +func (_c *ChannelMonitorHistoryCreate) SetModel(v string) *ChannelMonitorHistoryCreate { + _c.mutation.SetModel(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *ChannelMonitorHistoryCreate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetLatencyMs sets the "latency_ms" field. +func (_c *ChannelMonitorHistoryCreate) SetLatencyMs(v int) *ChannelMonitorHistoryCreate { + _c.mutation.SetLatencyMs(v) + return _c +} + +// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil. 
+func (_c *ChannelMonitorHistoryCreate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryCreate { + if v != nil { + _c.SetLatencyMs(*v) + } + return _c +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (_c *ChannelMonitorHistoryCreate) SetPingLatencyMs(v int) *ChannelMonitorHistoryCreate { + _c.mutation.SetPingLatencyMs(v) + return _c +} + +// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil. +func (_c *ChannelMonitorHistoryCreate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryCreate { + if v != nil { + _c.SetPingLatencyMs(*v) + } + return _c +} + +// SetMessage sets the "message" field. +func (_c *ChannelMonitorHistoryCreate) SetMessage(v string) *ChannelMonitorHistoryCreate { + _c.mutation.SetMessage(v) + return _c +} + +// SetNillableMessage sets the "message" field if the given value is not nil. +func (_c *ChannelMonitorHistoryCreate) SetNillableMessage(v *string) *ChannelMonitorHistoryCreate { + if v != nil { + _c.SetMessage(*v) + } + return _c +} + +// SetCheckedAt sets the "checked_at" field. +func (_c *ChannelMonitorHistoryCreate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryCreate { + _c.mutation.SetCheckedAt(v) + return _c +} + +// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil. +func (_c *ChannelMonitorHistoryCreate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryCreate { + if v != nil { + _c.SetCheckedAt(*v) + } + return _c +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. +func (_c *ChannelMonitorHistoryCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryCreate { + return _c.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorHistoryMutation object of the builder. +func (_c *ChannelMonitorHistoryCreate) Mutation() *ChannelMonitorHistoryMutation { + return _c.mutation +} + +// Save creates the ChannelMonitorHistory in the database. 
+func (_c *ChannelMonitorHistoryCreate) Save(ctx context.Context) (*ChannelMonitorHistory, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ChannelMonitorHistoryCreate) SaveX(ctx context.Context) *ChannelMonitorHistory { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ChannelMonitorHistoryCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorHistoryCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ChannelMonitorHistoryCreate) defaults() { + if _, ok := _c.mutation.Message(); !ok { + v := channelmonitorhistory.DefaultMessage + _c.mutation.SetMessage(v) + } + if _, ok := _c.mutation.CheckedAt(); !ok { + v := channelmonitorhistory.DefaultCheckedAt() + _c.mutation.SetCheckedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *ChannelMonitorHistoryCreate) check() error { + if _, ok := _c.mutation.MonitorID(); !ok { + return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorHistory.monitor_id"`)} + } + if _, ok := _c.mutation.Model(); !ok { + return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorHistory.model"`)} + } + if v, ok := _c.mutation.Model(); ok { + if err := channelmonitorhistory.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)} + } + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "ChannelMonitorHistory.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if err := channelmonitorhistory.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)} + } + } + if v, ok := _c.mutation.Message(); ok { + if err := channelmonitorhistory.MessageValidator(v); err != nil { + return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)} + } + } + if _, ok := _c.mutation.CheckedAt(); !ok { + return &ValidationError{Name: "checked_at", err: errors.New(`ent: missing required field "ChannelMonitorHistory.checked_at"`)} + } + if len(_c.mutation.MonitorIDs()) == 0 { + return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorHistory.monitor"`)} + } + return nil +} + +func (_c *ChannelMonitorHistoryCreate) sqlSave(ctx context.Context) (*ChannelMonitorHistory, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = 
&ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sqlgraph.CreateSpec) { + var ( + _node = &ChannelMonitorHistory{config: _c.config} + _spec = sqlgraph.NewCreateSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.Model(); ok { + _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value) + _node.Model = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := _c.mutation.LatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value) + _node.LatencyMs = &value + } + if value, ok := _c.mutation.PingLatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value) + _node.PingLatencyMs = &value + } + if value, ok := _c.mutation.Message(); ok { + _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value) + _node.Message = value + } + if value, ok := _c.mutation.CheckedAt(); ok { + _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value) + _node.CheckedAt = value + } + if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitorhistory.MonitorTable, + Columns: []string{channelmonitorhistory.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.MonitorID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, 
_spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitorHistory.Create(). +// SetMonitorID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorHistoryUpsert) { +// SetMonitorID(v+v). +// }). +// Exec(ctx) +func (_c *ChannelMonitorHistoryCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertOne { + _c.conflict = opts + return &ChannelMonitorHistoryUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorHistoryCreate) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorHistoryUpsertOne{ + create: _c, + } +} + +type ( + // ChannelMonitorHistoryUpsertOne is the builder for "upsert"-ing + // one ChannelMonitorHistory node. + ChannelMonitorHistoryUpsertOne struct { + create *ChannelMonitorHistoryCreate + } + + // ChannelMonitorHistoryUpsert is the "OnConflict" setter. + ChannelMonitorHistoryUpsert struct { + *sql.UpdateSet + } +) + +// SetMonitorID sets the "monitor_id" field. +func (u *ChannelMonitorHistoryUpsert) SetMonitorID(v int64) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldMonitorID, v) + return u +} + +// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create. 
+func (u *ChannelMonitorHistoryUpsert) UpdateMonitorID() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldMonitorID) + return u +} + +// SetModel sets the "model" field. +func (u *ChannelMonitorHistoryUpsert) SetModel(v string) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldModel, v) + return u +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsert) UpdateModel() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldModel) + return u +} + +// SetStatus sets the "status" field. +func (u *ChannelMonitorHistoryUpsert) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsert) UpdateStatus() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldStatus) + return u +} + +// SetLatencyMs sets the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsert) SetLatencyMs(v int) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldLatencyMs, v) + return u +} + +// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsert) UpdateLatencyMs() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldLatencyMs) + return u +} + +// AddLatencyMs adds v to the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsert) AddLatencyMs(v int) *ChannelMonitorHistoryUpsert { + u.Add(channelmonitorhistory.FieldLatencyMs, v) + return u +} + +// ClearLatencyMs clears the value of the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsert) ClearLatencyMs() *ChannelMonitorHistoryUpsert { + u.SetNull(channelmonitorhistory.FieldLatencyMs) + return u +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. 
+func (u *ChannelMonitorHistoryUpsert) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldPingLatencyMs, v) + return u +} + +// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsert) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldPingLatencyMs) + return u +} + +// AddPingLatencyMs adds v to the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsert) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsert { + u.Add(channelmonitorhistory.FieldPingLatencyMs, v) + return u +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsert) ClearPingLatencyMs() *ChannelMonitorHistoryUpsert { + u.SetNull(channelmonitorhistory.FieldPingLatencyMs) + return u +} + +// SetMessage sets the "message" field. +func (u *ChannelMonitorHistoryUpsert) SetMessage(v string) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldMessage, v) + return u +} + +// UpdateMessage sets the "message" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsert) UpdateMessage() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldMessage) + return u +} + +// ClearMessage clears the value of the "message" field. +func (u *ChannelMonitorHistoryUpsert) ClearMessage() *ChannelMonitorHistoryUpsert { + u.SetNull(channelmonitorhistory.FieldMessage) + return u +} + +// SetCheckedAt sets the "checked_at" field. +func (u *ChannelMonitorHistoryUpsert) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsert { + u.Set(channelmonitorhistory.FieldCheckedAt, v) + return u +} + +// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create. 
+func (u *ChannelMonitorHistoryUpsert) UpdateCheckedAt() *ChannelMonitorHistoryUpsert { + u.SetExcluded(channelmonitorhistory.FieldCheckedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ChannelMonitorHistoryUpsertOne) UpdateNewValues() *ChannelMonitorHistoryUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorHistoryUpsertOne) Ignore() *ChannelMonitorHistoryUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorHistoryUpsertOne) DoNothing() *ChannelMonitorHistoryUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreate.OnConflict +// documentation for more info. +func (u *ChannelMonitorHistoryUpsertOne) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorHistoryUpsert{UpdateSet: update}) + })) + return u +} + +// SetMonitorID sets the "monitor_id" field. 
+func (u *ChannelMonitorHistoryUpsertOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetMonitorID(v) + }) +} + +// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertOne) UpdateMonitorID() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateMonitorID() + }) +} + +// SetModel sets the "model" field. +func (u *ChannelMonitorHistoryUpsertOne) SetModel(v string) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertOne) UpdateModel() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateModel() + }) +} + +// SetStatus sets the "status" field. +func (u *ChannelMonitorHistoryUpsertOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertOne) UpdateStatus() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateStatus() + }) +} + +// SetLatencyMs sets the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetLatencyMs(v) + }) +} + +// AddLatencyMs adds v to the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.AddLatencyMs(v) + }) +} + +// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create. 
+func (u *ChannelMonitorHistoryUpsertOne) UpdateLatencyMs() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateLatencyMs() + }) +} + +// ClearLatencyMs clears the value of the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) ClearLatencyMs() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearLatencyMs() + }) +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetPingLatencyMs(v) + }) +} + +// AddPingLatencyMs adds v to the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.AddPingLatencyMs(v) + }) +} + +// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertOne) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdatePingLatencyMs() + }) +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearPingLatencyMs() + }) +} + +// SetMessage sets the "message" field. +func (u *ChannelMonitorHistoryUpsertOne) SetMessage(v string) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetMessage(v) + }) +} + +// UpdateMessage sets the "message" field to the value that was provided on create. 
+func (u *ChannelMonitorHistoryUpsertOne) UpdateMessage() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateMessage() + }) +} + +// ClearMessage clears the value of the "message" field. +func (u *ChannelMonitorHistoryUpsertOne) ClearMessage() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearMessage() + }) +} + +// SetCheckedAt sets the "checked_at" field. +func (u *ChannelMonitorHistoryUpsertOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetCheckedAt(v) + }) +} + +// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertOne) UpdateCheckedAt() *ChannelMonitorHistoryUpsertOne { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateCheckedAt() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorHistoryUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorHistoryCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorHistoryUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ChannelMonitorHistoryUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ChannelMonitorHistoryUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ChannelMonitorHistoryCreateBulk is the builder for creating many ChannelMonitorHistory entities in bulk. 
+type ChannelMonitorHistoryCreateBulk struct { + config + err error + builders []*ChannelMonitorHistoryCreate + conflict []sql.ConflictOption +} + +// Save creates the ChannelMonitorHistory entities in the database. +func (_c *ChannelMonitorHistoryCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorHistory, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ChannelMonitorHistory, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ChannelMonitorHistoryMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_c *ChannelMonitorHistoryCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorHistory { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ChannelMonitorHistoryCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorHistoryCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitorHistory.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorHistoryUpsert) { +// SetMonitorID(v+v). +// }). +// Exec(ctx) +func (_c *ChannelMonitorHistoryCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertBulk { + _c.conflict = opts + return &ChannelMonitorHistoryUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorHistoryCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorHistoryUpsertBulk{ + create: _c, + } +} + +// ChannelMonitorHistoryUpsertBulk is the builder for "upsert"-ing +// a bulk of ChannelMonitorHistory nodes. 
+type ChannelMonitorHistoryUpsertBulk struct { + create *ChannelMonitorHistoryCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ChannelMonitorHistoryUpsertBulk) UpdateNewValues() *ChannelMonitorHistoryUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ChannelMonitorHistory.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorHistoryUpsertBulk) Ignore() *ChannelMonitorHistoryUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorHistoryUpsertBulk) DoNothing() *ChannelMonitorHistoryUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreateBulk.OnConflict +// documentation for more info. +func (u *ChannelMonitorHistoryUpsertBulk) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorHistoryUpsert{UpdateSet: update}) + })) + return u +} + +// SetMonitorID sets the "monitor_id" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetMonitorID(v) + }) +} + +// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create. 
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateMonitorID() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateMonitorID() + }) +} + +// SetModel sets the "model" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetModel(v string) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdateModel() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateModel() + }) +} + +// SetStatus sets the "status" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdateStatus() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateStatus() + }) +} + +// SetLatencyMs sets the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetLatencyMs(v) + }) +} + +// AddLatencyMs adds v to the "latency_ms" field. +func (u *ChannelMonitorHistoryUpsertBulk) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.AddLatencyMs(v) + }) +} + +// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdateLatencyMs() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateLatencyMs() + }) +} + +// ClearLatencyMs clears the value of the "latency_ms" field. 
+func (u *ChannelMonitorHistoryUpsertBulk) ClearLatencyMs() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearLatencyMs() + }) +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetPingLatencyMs(v) + }) +} + +// AddPingLatencyMs adds v to the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertBulk) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.AddPingLatencyMs(v) + }) +} + +// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdatePingLatencyMs() + }) +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (u *ChannelMonitorHistoryUpsertBulk) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearPingLatencyMs() + }) +} + +// SetMessage sets the "message" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetMessage(v string) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetMessage(v) + }) +} + +// UpdateMessage sets the "message" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdateMessage() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateMessage() + }) +} + +// ClearMessage clears the value of the "message" field. 
+func (u *ChannelMonitorHistoryUpsertBulk) ClearMessage() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.ClearMessage() + }) +} + +// SetCheckedAt sets the "checked_at" field. +func (u *ChannelMonitorHistoryUpsertBulk) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.SetCheckedAt(v) + }) +} + +// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create. +func (u *ChannelMonitorHistoryUpsertBulk) UpdateCheckedAt() *ChannelMonitorHistoryUpsertBulk { + return u.Update(func(s *ChannelMonitorHistoryUpsert) { + s.UpdateCheckedAt() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorHistoryUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorHistoryCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorHistoryCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorHistoryUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitorhistory_delete.go b/backend/ent/channelmonitorhistory_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..97110e69191f8a4f493068c87a8d17326dbab871 --- /dev/null +++ b/backend/ent/channelmonitorhistory_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorHistoryDelete is the builder for deleting a ChannelMonitorHistory entity. +type ChannelMonitorHistoryDelete struct { + config + hooks []Hook + mutation *ChannelMonitorHistoryMutation +} + +// Where appends a list predicates to the ChannelMonitorHistoryDelete builder. +func (_d *ChannelMonitorHistoryDelete) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ChannelMonitorHistoryDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ChannelMonitorHistoryDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ChannelMonitorHistoryDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ChannelMonitorHistoryDeleteOne is the builder for deleting a single ChannelMonitorHistory entity. +type ChannelMonitorHistoryDeleteOne struct { + _d *ChannelMonitorHistoryDelete +} + +// Where appends a list predicates to the ChannelMonitorHistoryDelete builder. 
+func (_d *ChannelMonitorHistoryDeleteOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ChannelMonitorHistoryDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{channelmonitorhistory.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ChannelMonitorHistoryDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitorhistory_query.go b/backend/ent/channelmonitorhistory_query.go new file mode 100644 index 0000000000000000000000000000000000000000..1fb872ad1b14582005bdd4b000a5e3c593450c1a --- /dev/null +++ b/backend/ent/channelmonitorhistory_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorHistoryQuery is the builder for querying ChannelMonitorHistory entities. +type ChannelMonitorHistoryQuery struct { + config + ctx *QueryContext + order []channelmonitorhistory.OrderOption + inters []Interceptor + predicates []predicate.ChannelMonitorHistory + withMonitor *ChannelMonitorQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ChannelMonitorHistoryQuery builder. 
+func (_q *ChannelMonitorHistoryQuery) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ChannelMonitorHistoryQuery) Limit(limit int) *ChannelMonitorHistoryQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *ChannelMonitorHistoryQuery) Offset(offset int) *ChannelMonitorHistoryQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ChannelMonitorHistoryQuery) Unique(unique bool) *ChannelMonitorHistoryQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ChannelMonitorHistoryQuery) Order(o ...channelmonitorhistory.OrderOption) *ChannelMonitorHistoryQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryMonitor chains the current query on the "monitor" edge. +func (_q *ChannelMonitorHistoryQuery) QueryMonitor() *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, selector), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first ChannelMonitorHistory entity from the query. 
+// Returns a *NotFoundError when no ChannelMonitorHistory was found. +func (_q *ChannelMonitorHistoryQuery) First(ctx context.Context) (*ChannelMonitorHistory, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{channelmonitorhistory.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) FirstX(ctx context.Context) *ChannelMonitorHistory { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ChannelMonitorHistory ID from the query. +// Returns a *NotFoundError when no ChannelMonitorHistory ID was found. +func (_q *ChannelMonitorHistoryQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{channelmonitorhistory.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ChannelMonitorHistory entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ChannelMonitorHistory entity is found. +// Returns a *NotFoundError when no ChannelMonitorHistory entities are found. 
+func (_q *ChannelMonitorHistoryQuery) Only(ctx context.Context) (*ChannelMonitorHistory, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{channelmonitorhistory.Label} + default: + return nil, &NotSingularError{channelmonitorhistory.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) OnlyX(ctx context.Context) *ChannelMonitorHistory { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ChannelMonitorHistory ID in the query. +// Returns a *NotSingularError when more than one ChannelMonitorHistory ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ChannelMonitorHistoryQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{channelmonitorhistory.Label} + default: + err = &NotSingularError{channelmonitorhistory.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ChannelMonitorHistories. +func (_q *ChannelMonitorHistoryQuery) All(ctx context.Context) ([]*ChannelMonitorHistory, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ChannelMonitorHistory, *ChannelMonitorHistoryQuery]() + return withInterceptors[[]*ChannelMonitorHistory](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. 
+func (_q *ChannelMonitorHistoryQuery) AllX(ctx context.Context) []*ChannelMonitorHistory { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ChannelMonitorHistory IDs. +func (_q *ChannelMonitorHistoryQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(channelmonitorhistory.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ChannelMonitorHistoryQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorHistoryQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ChannelMonitorHistoryQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *ChannelMonitorHistoryQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. 
+func (_q *ChannelMonitorHistoryQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ChannelMonitorHistoryQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ChannelMonitorHistoryQuery) Clone() *ChannelMonitorHistoryQuery { + if _q == nil { + return nil + } + return &ChannelMonitorHistoryQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]channelmonitorhistory.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ChannelMonitorHistory{}, _q.predicates...), + withMonitor: _q.withMonitor.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithMonitor tells the query-builder to eager-load the nodes that are connected to +// the "monitor" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChannelMonitorHistoryQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorHistoryQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withMonitor = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// MonitorID int64 `json:"monitor_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ChannelMonitorHistory.Query(). +// GroupBy(channelmonitorhistory.FieldMonitorID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *ChannelMonitorHistoryGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &ChannelMonitorHistoryGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = channelmonitorhistory.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// MonitorID int64 `json:"monitor_id,omitempty"` +// } +// +// client.ChannelMonitorHistory.Query(). +// Select(channelmonitorhistory.FieldMonitorID). +// Scan(ctx, &v) +func (_q *ChannelMonitorHistoryQuery) Select(fields ...string) *ChannelMonitorHistorySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ChannelMonitorHistorySelect{ChannelMonitorHistoryQuery: _q} + sbuild.label = channelmonitorhistory.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ChannelMonitorHistorySelect configured with the given aggregations. +func (_q *ChannelMonitorHistoryQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *ChannelMonitorHistoryQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !channelmonitorhistory.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ChannelMonitorHistoryQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorHistory, error) { + var ( + nodes = []*ChannelMonitorHistory{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withMonitor != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ChannelMonitorHistory).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ChannelMonitorHistory{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withMonitor; query != nil { + if err := _q.loadMonitor(ctx, query, nodes, nil, + func(n *ChannelMonitorHistory, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *ChannelMonitorHistoryQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorHistory, init func(*ChannelMonitorHistory), assign func(*ChannelMonitorHistory, *ChannelMonitor)) error { + ids := make([]int64, 
0, len(nodes)) + nodeids := make(map[int64][]*ChannelMonitorHistory) + for i := range nodes { + fk := nodes[i].MonitorID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(channelmonitor.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *ChannelMonitorHistoryQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ChannelMonitorHistoryQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID) + for i := range fields { + if fields[i] != channelmonitorhistory.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withMonitor != nil { + _spec.Node.AddColumnOnce(channelmonitorhistory.FieldMonitorID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + 
if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ChannelMonitorHistoryQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(channelmonitorhistory.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = channelmonitorhistory.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ChannelMonitorHistoryQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorHistoryQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. 
+func (_q *ChannelMonitorHistoryQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorHistoryQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ChannelMonitorHistoryGroupBy is the group-by builder for ChannelMonitorHistory entities. +type ChannelMonitorHistoryGroupBy struct { + selector + build *ChannelMonitorHistoryQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ChannelMonitorHistoryGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistoryGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *ChannelMonitorHistoryGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistoryGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ChannelMonitorHistoryGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) 
+ if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ChannelMonitorHistorySelect is the builder for selecting fields of ChannelMonitorHistory entities. +type ChannelMonitorHistorySelect struct { + *ChannelMonitorHistoryQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ChannelMonitorHistorySelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *ChannelMonitorHistorySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistorySelect](ctx, _s.ChannelMonitorHistoryQuery, _s, _s.inters, v) +} + +func (_s *ChannelMonitorHistorySelect) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/channelmonitorhistory_update.go b/backend/ent/channelmonitorhistory_update.go new file mode 100644 index 0000000000000000000000000000000000000000..a85a8072a9e6cbc0e7907f9e211eef103a6279dd --- /dev/null +++ b/backend/ent/channelmonitorhistory_update.go @@ -0,0 +1,635 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorHistoryUpdate is the builder for updating ChannelMonitorHistory entities. +type ChannelMonitorHistoryUpdate struct { + config + hooks []Hook + mutation *ChannelMonitorHistoryMutation +} + +// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder. +func (_u *ChannelMonitorHistoryUpdate) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetMonitorID sets the "monitor_id" field. +func (_u *ChannelMonitorHistoryUpdate) SetMonitorID(v int64) *ChannelMonitorHistoryUpdate { + _u.mutation.SetMonitorID(v) + return _u +} + +// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetMonitorID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *ChannelMonitorHistoryUpdate) SetModel(v string) *ChannelMonitorHistoryUpdate { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. 
+func (_u *ChannelMonitorHistoryUpdate) SetNillableModel(v *string) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *ChannelMonitorHistoryUpdate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdate) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetLatencyMs sets the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) SetLatencyMs(v int) *ChannelMonitorHistoryUpdate { + _u.mutation.ResetLatencyMs() + _u.mutation.SetLatencyMs(v) + return _u +} + +// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetLatencyMs(*v) + } + return _u +} + +// AddLatencyMs adds value to the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) AddLatencyMs(v int) *ChannelMonitorHistoryUpdate { + _u.mutation.AddLatencyMs(v) + return _u +} + +// ClearLatencyMs clears the value of the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) ClearLatencyMs() *ChannelMonitorHistoryUpdate { + _u.mutation.ClearLatencyMs() + return _u +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdate { + _u.mutation.ResetPingLatencyMs() + _u.mutation.SetPingLatencyMs(v) + return _u +} + +// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil. 
+func (_u *ChannelMonitorHistoryUpdate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetPingLatencyMs(*v) + } + return _u +} + +// AddPingLatencyMs adds value to the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdate { + _u.mutation.AddPingLatencyMs(v) + return _u +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdate) ClearPingLatencyMs() *ChannelMonitorHistoryUpdate { + _u.mutation.ClearPingLatencyMs() + return _u +} + +// SetMessage sets the "message" field. +func (_u *ChannelMonitorHistoryUpdate) SetMessage(v string) *ChannelMonitorHistoryUpdate { + _u.mutation.SetMessage(v) + return _u +} + +// SetNillableMessage sets the "message" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdate) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetMessage(*v) + } + return _u +} + +// ClearMessage clears the value of the "message" field. +func (_u *ChannelMonitorHistoryUpdate) ClearMessage() *ChannelMonitorHistoryUpdate { + _u.mutation.ClearMessage() + return _u +} + +// SetCheckedAt sets the "checked_at" field. +func (_u *ChannelMonitorHistoryUpdate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdate { + _u.mutation.SetCheckedAt(v) + return _u +} + +// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdate { + if v != nil { + _u.SetCheckedAt(*v) + } + return _u +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorHistoryUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdate { + return _u.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorHistoryMutation object of the builder. 
+func (_u *ChannelMonitorHistoryUpdate) Mutation() *ChannelMonitorHistoryMutation { + return _u.mutation +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorHistoryUpdate) ClearMonitor() *ChannelMonitorHistoryUpdate { + _u.mutation.ClearMonitor() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ChannelMonitorHistoryUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorHistoryUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ChannelMonitorHistoryUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorHistoryUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *ChannelMonitorHistoryUpdate) check() error { + if v, ok := _u.mutation.Model(); ok { + if err := channelmonitorhistory.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := channelmonitorhistory.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)} + } + } + if v, ok := _u.mutation.Message(); ok { + if err := channelmonitorhistory.MessageValidator(v); err != nil { + return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)} + } + } + if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`) + } + return nil +} + +func (_u *ChannelMonitorHistoryUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value) + } + if value, ok := _u.mutation.LatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLatencyMs(); ok { + _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value) + } + if 
_u.mutation.LatencyMsCleared() { + _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt) + } + if value, ok := _u.mutation.PingLatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPingLatencyMs(); ok { + _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value) + } + if _u.mutation.PingLatencyMsCleared() { + _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt) + } + if value, ok := _u.mutation.Message(); ok { + _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value) + } + if _u.mutation.MessageCleared() { + _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString) + } + if value, ok := _u.mutation.CheckedAt(); ok { + _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value) + } + if _u.mutation.MonitorCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitorhistory.MonitorTable, + Columns: []string{channelmonitorhistory.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitorhistory.MonitorTable, + Columns: []string{channelmonitorhistory.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitorhistory.Label} + } else if sqlgraph.IsConstraintError(err) { + 
err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ChannelMonitorHistoryUpdateOne is the builder for updating a single ChannelMonitorHistory entity. +type ChannelMonitorHistoryUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ChannelMonitorHistoryMutation +} + +// SetMonitorID sets the "monitor_id" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpdateOne { + _u.mutation.SetMonitorID(v) + return _u +} + +// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetMonitorID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetModel(v string) *ChannelMonitorHistoryUpdateOne { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillableModel(v *string) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetLatencyMs sets the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpdateOne { + _u.mutation.ResetLatencyMs() + _u.mutation.SetLatencyMs(v) + return _u +} + +// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil. 
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetLatencyMs(*v) + } + return _u +} + +// AddLatencyMs adds value to the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpdateOne { + _u.mutation.AddLatencyMs(v) + return _u +} + +// ClearLatencyMs clears the value of the "latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) ClearLatencyMs() *ChannelMonitorHistoryUpdateOne { + _u.mutation.ClearLatencyMs() + return _u +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne { + _u.mutation.ResetPingLatencyMs() + _u.mutation.SetPingLatencyMs(v) + return _u +} + +// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetPingLatencyMs(*v) + } + return _u +} + +// AddPingLatencyMs adds value to the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne { + _u.mutation.AddPingLatencyMs(v) + return _u +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (_u *ChannelMonitorHistoryUpdateOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpdateOne { + _u.mutation.ClearPingLatencyMs() + return _u +} + +// SetMessage sets the "message" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetMessage(v string) *ChannelMonitorHistoryUpdateOne { + _u.mutation.SetMessage(v) + return _u +} + +// SetNillableMessage sets the "message" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetMessage(*v) + } + return _u +} + +// ClearMessage clears the value of the "message" field. 
+func (_u *ChannelMonitorHistoryUpdateOne) ClearMessage() *ChannelMonitorHistoryUpdateOne { + _u.mutation.ClearMessage() + return _u +} + +// SetCheckedAt sets the "checked_at" field. +func (_u *ChannelMonitorHistoryUpdateOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdateOne { + _u.mutation.SetCheckedAt(v) + return _u +} + +// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil. +func (_u *ChannelMonitorHistoryUpdateOne) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdateOne { + if v != nil { + _u.SetCheckedAt(*v) + } + return _u +} + +// SetMonitor sets the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorHistoryUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdateOne { + return _u.SetMonitorID(v.ID) +} + +// Mutation returns the ChannelMonitorHistoryMutation object of the builder. +func (_u *ChannelMonitorHistoryUpdateOne) Mutation() *ChannelMonitorHistoryMutation { + return _u.mutation +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. +func (_u *ChannelMonitorHistoryUpdateOne) ClearMonitor() *ChannelMonitorHistoryUpdateOne { + _u.mutation.ClearMonitor() + return _u +} + +// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder. +func (_u *ChannelMonitorHistoryUpdateOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ChannelMonitorHistoryUpdateOne) Select(field string, fields ...string) *ChannelMonitorHistoryUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ChannelMonitorHistory entity. 
+func (_u *ChannelMonitorHistoryUpdateOne) Save(ctx context.Context) (*ChannelMonitorHistory, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorHistoryUpdateOne) SaveX(ctx context.Context) *ChannelMonitorHistory { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *ChannelMonitorHistoryUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorHistoryUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorHistoryUpdateOne) check() error { + if v, ok := _u.mutation.Model(); ok { + if err := channelmonitorhistory.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := channelmonitorhistory.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)} + } + } + if v, ok := _u.mutation.Message(); ok { + if err := channelmonitorhistory.MessageValidator(v); err != nil { + return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)} + } + } + if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`) + } + return nil +} + +func (_u *ChannelMonitorHistoryUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorHistory, err error) { + if err := _u.check(); err != nil { + return _node, err + } + 
_spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorHistory.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID) + for _, f := range fields { + if !channelmonitorhistory.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != channelmonitorhistory.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value) + } + if value, ok := _u.mutation.LatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedLatencyMs(); ok { + _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value) + } + if _u.mutation.LatencyMsCleared() { + _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt) + } + if value, ok := _u.mutation.PingLatencyMs(); ok { + _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPingLatencyMs(); ok { + _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value) + } + if _u.mutation.PingLatencyMsCleared() { + _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt) + } + if value, ok 
:= _u.mutation.Message(); ok { + _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value) + } + if _u.mutation.MessageCleared() { + _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString) + } + if value, ok := _u.mutation.CheckedAt(); ok { + _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value) + } + if _u.mutation.MonitorCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitorhistory.MonitorTable, + Columns: []string{channelmonitorhistory.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: channelmonitorhistory.MonitorTable, + Columns: []string{channelmonitorhistory.MonitorColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &ChannelMonitorHistory{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitorhistory.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate.go new file mode 100644 index 0000000000000000000000000000000000000000..b8429a4d71efd1da12de7c654fc640ec1e460ee9 --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate.go @@ 
-0,0 +1,216 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" +) + +// ChannelMonitorRequestTemplate is the model entity for the ChannelMonitorRequestTemplate schema. +type ChannelMonitorRequestTemplate struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Provider holds the value of the "provider" field. + Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"` + // Description holds the value of the "description" field. + Description string `json:"description,omitempty"` + // ExtraHeaders holds the value of the "extra_headers" field. + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // BodyOverrideMode holds the value of the "body_override_mode" field. + BodyOverrideMode string `json:"body_override_mode,omitempty"` + // BodyOverride holds the value of the "body_override" field. + BodyOverride map[string]interface{} `json:"body_override,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ChannelMonitorRequestTemplateQuery when eager-loading is set. + Edges ChannelMonitorRequestTemplateEdges `json:"edges"` + selectValues sql.SelectValues +} + +// ChannelMonitorRequestTemplateEdges holds the relations/edges for other nodes in the graph. +type ChannelMonitorRequestTemplateEdges struct { + // Monitors holds the value of the monitors edge. 
+ Monitors []*ChannelMonitor `json:"monitors,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// MonitorsOrErr returns the Monitors value or an error if the edge +// was not loaded in eager-loading. +func (e ChannelMonitorRequestTemplateEdges) MonitorsOrErr() ([]*ChannelMonitor, error) { + if e.loadedTypes[0] { + return e.Monitors, nil + } + return nil, &NotLoadedError{edge: "monitors"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case channelmonitorrequesttemplate.FieldExtraHeaders, channelmonitorrequesttemplate.FieldBodyOverride: + values[i] = new([]byte) + case channelmonitorrequesttemplate.FieldID: + values[i] = new(sql.NullInt64) + case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode: + values[i] = new(sql.NullString) + case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ChannelMonitorRequestTemplate fields. 
+func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case channelmonitorrequesttemplate.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case channelmonitorrequesttemplate.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case channelmonitorrequesttemplate.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case channelmonitorrequesttemplate.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case channelmonitorrequesttemplate.FieldProvider: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider", values[i]) + } else if value.Valid { + _m.Provider = channelmonitorrequesttemplate.Provider(value.String) + } + case channelmonitorrequesttemplate.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + _m.Description = value.String + } + case channelmonitorrequesttemplate.FieldExtraHeaders: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field extra_headers", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil { + return fmt.Errorf("unmarshal 
field extra_headers: %w", err) + } + } + case channelmonitorrequesttemplate.FieldBodyOverrideMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field body_override_mode", values[i]) + } else if value.Valid { + _m.BodyOverrideMode = value.String + } + case channelmonitorrequesttemplate.FieldBodyOverride: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field body_override", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil { + return fmt.Errorf("unmarshal field body_override: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorRequestTemplate. +// This includes values selected through modifiers, order, etc. +func (_m *ChannelMonitorRequestTemplate) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryMonitors queries the "monitors" edge of the ChannelMonitorRequestTemplate entity. +func (_m *ChannelMonitorRequestTemplate) QueryMonitors() *ChannelMonitorQuery { + return NewChannelMonitorRequestTemplateClient(_m.config).QueryMonitors(_m) +} + +// Update returns a builder for updating this ChannelMonitorRequestTemplate. +// Note that you need to call ChannelMonitorRequestTemplate.Unwrap() before calling this method if this ChannelMonitorRequestTemplate +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ChannelMonitorRequestTemplate) Update() *ChannelMonitorRequestTemplateUpdateOne { + return NewChannelMonitorRequestTemplateClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ChannelMonitorRequestTemplate entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. 
+func (_m *ChannelMonitorRequestTemplate) Unwrap() *ChannelMonitorRequestTemplate { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ChannelMonitorRequestTemplate is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ChannelMonitorRequestTemplate) String() string { + var builder strings.Builder + builder.WriteString("ChannelMonitorRequestTemplate(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + builder.WriteString("provider=") + builder.WriteString(fmt.Sprintf("%v", _m.Provider)) + builder.WriteString(", ") + builder.WriteString("description=") + builder.WriteString(_m.Description) + builder.WriteString(", ") + builder.WriteString("extra_headers=") + builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders)) + builder.WriteString(", ") + builder.WriteString("body_override_mode=") + builder.WriteString(_m.BodyOverrideMode) + builder.WriteString(", ") + builder.WriteString("body_override=") + builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride)) + builder.WriteByte(')') + return builder.String() +} + +// ChannelMonitorRequestTemplates is a parsable slice of ChannelMonitorRequestTemplate. 
+type ChannelMonitorRequestTemplates []*ChannelMonitorRequestTemplate diff --git a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go new file mode 100644 index 0000000000000000000000000000000000000000..65b8d641cfc12191bf58afd79b7fc650583c0a90 --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go @@ -0,0 +1,172 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitorrequesttemplate + +import ( + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the channelmonitorrequesttemplate type in the database. + Label = "channel_monitor_request_template" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldProvider holds the string denoting the provider field in the database. + FieldProvider = "provider" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // FieldExtraHeaders holds the string denoting the extra_headers field in the database. + FieldExtraHeaders = "extra_headers" + // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database. + FieldBodyOverrideMode = "body_override_mode" + // FieldBodyOverride holds the string denoting the body_override field in the database. + FieldBodyOverride = "body_override" + // EdgeMonitors holds the string denoting the monitors edge name in mutations. 
+ EdgeMonitors = "monitors" + // Table holds the table name of the channelmonitorrequesttemplate in the database. + Table = "channel_monitor_request_templates" + // MonitorsTable is the table that holds the monitors relation/edge. + MonitorsTable = "channel_monitors" + // MonitorsInverseTable is the table name for the ChannelMonitor entity. + // It exists in this package in order to avoid circular dependency with the "channelmonitor" package. + MonitorsInverseTable = "channel_monitors" + // MonitorsColumn is the table column denoting the monitors relation/edge. + MonitorsColumn = "template_id" +) + +// Columns holds all SQL columns for channelmonitorrequesttemplate fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldProvider, + FieldDescription, + FieldExtraHeaders, + FieldBodyOverrideMode, + FieldBodyOverride, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // DefaultDescription holds the default value on creation for the "description" field. + DefaultDescription string + // DescriptionValidator is a validator for the "description" field. It is called by the builders before save. + DescriptionValidator func(string) error + // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field. 
+ DefaultExtraHeaders map[string]string + // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field. + DefaultBodyOverrideMode string + // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save. + BodyOverrideModeValidator func(string) error +) + +// Provider defines the type for the "provider" enum field. +type Provider string + +// Provider values. +const ( + ProviderOpenai Provider = "openai" + ProviderAnthropic Provider = "anthropic" + ProviderGemini Provider = "gemini" +) + +func (pr Provider) String() string { + return string(pr) +} + +// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save. +func ProviderValidator(pr Provider) error { + switch pr { + case ProviderOpenai, ProviderAnthropic, ProviderGemini: + return nil + default: + return fmt.Errorf("channelmonitorrequesttemplate: invalid enum value for provider field: %q", pr) + } +} + +// OrderOption defines the ordering options for the ChannelMonitorRequestTemplate queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByProvider orders the results by the provider field. 
+func ByProvider(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProvider, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} + +// ByBodyOverrideMode orders the results by the body_override_mode field. +func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc() +} + +// ByMonitorsCount orders the results by monitors count. +func ByMonitorsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newMonitorsStep(), opts...) + } +} + +// ByMonitors orders the results by monitors terms. +func ByMonitors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newMonitorsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newMonitorsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(MonitorsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn), + ) +} diff --git a/backend/ent/channelmonitorrequesttemplate/where.go b/backend/ent/channelmonitorrequesttemplate/where.go new file mode 100644 index 0000000000000000000000000000000000000000..b95e5df038945a8f69ea8394c2c12e86a67b2b2c --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate/where.go @@ -0,0 +1,434 @@ +// Code generated by ent, DO NOT EDIT. + +package channelmonitorrequesttemplate + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. 
+func ID(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
+func CreatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v)) +} + +// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ. +func BodyOverrideMode(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. 
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. 
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. 
+func NameGT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldName, v)) +} + +// ProviderEQ applies the EQ predicate on the "provider" field. 
+func ProviderEQ(v Provider) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldProvider, v)) +} + +// ProviderNEQ applies the NEQ predicate on the "provider" field. +func ProviderNEQ(v Provider) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldProvider, v)) +} + +// ProviderIn applies the In predicate on the "provider" field. +func ProviderIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldProvider, vs...)) +} + +// ProviderNotIn applies the NotIn predicate on the "provider" field. +func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. +func DescriptionNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. 
+func DescriptionGT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. +func DescriptionIsNil() predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. 
+func DescriptionNotNil() predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldDescription, v)) +} + +// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field. +func BodyOverrideModeEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field. +func BodyOverrideModeNEQ(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field. +func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldBodyOverrideMode, vs...)) +} + +// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field. +func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldBodyOverrideMode, vs...)) +} + +// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field. 
+func BodyOverrideModeGT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field. +func BodyOverrideModeGTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field. +func BodyOverrideModeLT(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field. +func BodyOverrideModeLTE(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field. +func BodyOverrideModeContains(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field. +func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field. +func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field. 
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldBodyOverrideMode, v)) +} + +// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field. +func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldBodyOverrideMode, v)) +} + +// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field. +func BodyOverrideIsNil() predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldBodyOverride)) +} + +// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field. +func BodyOverrideNotNil() predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldBodyOverride)) +} + +// HasMonitors applies the HasEdge predicate on the "monitors" edge. +func HasMonitors() predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasMonitorsWith applies the HasEdge predicate on the "monitors" edge with a given conditions (other predicates). +func HasMonitorsWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) { + step := newMonitorsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate { + return predicate.ChannelMonitorRequestTemplate(sql.NotPredicates(p)) +} diff --git a/backend/ent/channelmonitorrequesttemplate_create.go b/backend/ent/channelmonitorrequesttemplate_create.go new file mode 100644 index 0000000000000000000000000000000000000000..1ba842cd7f56beae1ac4cb3a7f6fd93c6d4270fc --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate_create.go @@ -0,0 +1,942 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" +) + +// ChannelMonitorRequestTemplateCreate is the builder for creating a ChannelMonitorRequestTemplate entity. +type ChannelMonitorRequestTemplateCreate struct { + config + mutation *ChannelMonitorRequestTemplateMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetCreatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. 
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *ChannelMonitorRequestTemplateCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetName(v string) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetName(v) + return _c +} + +// SetProvider sets the "provider" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetProvider(v) + return _c +} + +// SetDescription sets the "description" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetDescription(v) + return _c +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_c *ChannelMonitorRequestTemplateCreate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateCreate { + if v != nil { + _c.SetDescription(*v) + } + return _c +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetExtraHeaders(v) + return _c +} + +// SetBodyOverrideMode sets the "body_override_mode" field. 
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetBodyOverrideMode(v) + return _c +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_c *ChannelMonitorRequestTemplateCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateCreate { + if v != nil { + _c.SetBodyOverrideMode(*v) + } + return _c +} + +// SetBodyOverride sets the "body_override" field. +func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateCreate { + _c.mutation.SetBodyOverride(v) + return _c +} + +// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs. +func (_c *ChannelMonitorRequestTemplateCreate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateCreate { + _c.mutation.AddMonitorIDs(ids...) + return _c +} + +// AddMonitors adds the "monitors" edges to the ChannelMonitor entity. +func (_c *ChannelMonitorRequestTemplateCreate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddMonitorIDs(ids...) +} + +// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder. +func (_c *ChannelMonitorRequestTemplateCreate) Mutation() *ChannelMonitorRequestTemplateMutation { + return _c.mutation +} + +// Save creates the ChannelMonitorRequestTemplate in the database. +func (_c *ChannelMonitorRequestTemplateCreate) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ChannelMonitorRequestTemplateCreate) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *ChannelMonitorRequestTemplateCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorRequestTemplateCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ChannelMonitorRequestTemplateCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := channelmonitorrequesttemplate.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := channelmonitorrequesttemplate.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Description(); !ok { + v := channelmonitorrequesttemplate.DefaultDescription + _c.mutation.SetDescription(v) + } + if _, ok := _c.mutation.ExtraHeaders(); !ok { + v := channelmonitorrequesttemplate.DefaultExtraHeaders + _c.mutation.SetExtraHeaders(v) + } + if _, ok := _c.mutation.BodyOverrideMode(); !ok { + v := channelmonitorrequesttemplate.DefaultBodyOverrideMode + _c.mutation.SetBodyOverrideMode(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *ChannelMonitorRequestTemplateCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := channelmonitorrequesttemplate.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)} + } + } + if _, ok := _c.mutation.Provider(); !ok { + return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.provider"`)} + } + if v, ok := _c.mutation.Provider(); ok { + if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)} + } + } + if v, ok := _c.mutation.Description(); ok { + if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil { + return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)} + } + } + if _, ok := _c.mutation.ExtraHeaders(); !ok { + return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.extra_headers"`)} + } + if _, ok := _c.mutation.BodyOverrideMode(); !ok { + return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field 
"ChannelMonitorRequestTemplate.body_override_mode"`)} + } + if v, ok := _c.mutation.BodyOverrideMode(); ok { + if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_c *ChannelMonitorRequestTemplateCreate) sqlSave(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequestTemplate, *sqlgraph.CreateSpec) { + var ( + _node = &ChannelMonitorRequestTemplate{config: _c.config} + _spec = sqlgraph.NewCreateSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Provider(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value) + _node.Provider = value + } + if value, ok := _c.mutation.Description(); ok { + 
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value) + _node.Description = value + } + if value, ok := _c.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value) + _node.ExtraHeaders = value + } + if value, ok := _c.mutation.BodyOverrideMode(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value) + _node.BodyOverrideMode = value + } + if value, ok := _c.mutation.BodyOverride(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value) + _node.BodyOverride = value + } + if nodes := _c.mutation.MonitorsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitorRequestTemplate.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (_c *ChannelMonitorRequestTemplateCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertOne { + _c.conflict = opts + return &ChannelMonitorRequestTemplateUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorRequestTemplateCreate) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorRequestTemplateUpsertOne{ + create: _c, + } +} + +type ( + // ChannelMonitorRequestTemplateUpsertOne is the builder for "upsert"-ing + // one ChannelMonitorRequestTemplate node. + ChannelMonitorRequestTemplateUpsertOne struct { + create *ChannelMonitorRequestTemplateCreate + } + + // ChannelMonitorRequestTemplateUpsert is the "OnConflict" setter. + ChannelMonitorRequestTemplateUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsert) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetName(v string) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. 
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateName() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldName) + return u +} + +// SetProvider sets the "provider" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldProvider, v) + return u +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldProvider) + return u +} + +// SetDescription sets the "description" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldDescription, v) + return u +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsert) UpdateDescription() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldDescription) + return u +} + +// ClearDescription clears the value of the "description" field. +func (u *ChannelMonitorRequestTemplateUpsert) ClearDescription() *ChannelMonitorRequestTemplateUpsert { + u.SetNull(channelmonitorrequesttemplate.FieldDescription) + return u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldExtraHeaders, v) + return u +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. 
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldExtraHeaders) + return u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldBodyOverrideMode, v) + return u +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverrideMode) + return u +} + +// SetBodyOverride sets the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsert { + u.Set(channelmonitorrequesttemplate.FieldBodyOverride, v) + return u +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsert { + u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverride) + return u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsert) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsert { + u.SetNull(channelmonitorrequesttemplate.FieldBodyOverride) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorRequestTemplateUpsertOne) Ignore() *ChannelMonitorRequestTemplateUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorRequestTemplateUpsertOne) DoNothing() *ChannelMonitorRequestTemplateUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreate.OnConflict +// documentation for more info. +func (u *ChannelMonitorRequestTemplateUpsertOne) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetName(v string) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateName() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateName() + }) +} + +// SetProvider sets the "provider" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetProvider(v) + }) +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateProvider() + }) +} + +// SetDescription sets the "description" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateDescription() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. 
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearDescription() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.ClearDescription() + }) +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetExtraHeaders(v) + }) +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateExtraHeaders() + }) +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetBodyOverrideMode(v) + }) +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateBodyOverrideMode() + }) +} + +// SetBodyOverride sets the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetBodyOverride(v) + }) +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. 
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateBodyOverride() + }) +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsertOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertOne { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.ClearBodyOverride() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorRequestTemplateUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorRequestTemplateUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ChannelMonitorRequestTemplateUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ChannelMonitorRequestTemplateUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ChannelMonitorRequestTemplateCreateBulk is the builder for creating many ChannelMonitorRequestTemplate entities in bulk. +type ChannelMonitorRequestTemplateCreateBulk struct { + config + err error + builders []*ChannelMonitorRequestTemplateCreate + conflict []sql.ConflictOption +} + +// Save creates the ChannelMonitorRequestTemplate entities in the database. 
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ChannelMonitorRequestTemplate, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ChannelMonitorRequestTemplateMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_c *ChannelMonitorRequestTemplateCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorRequestTemplate { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ChannelMonitorRequestTemplateCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ChannelMonitorRequestTemplateCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ChannelMonitorRequestTemplate.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertBulk { + _c.conflict = opts + return &ChannelMonitorRequestTemplateUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ChannelMonitorRequestTemplateUpsertBulk{ + create: _c, + } +} + +// ChannelMonitorRequestTemplateUpsertBulk is the builder for "upsert"-ing +// a bulk of ChannelMonitorRequestTemplate nodes. 
+type ChannelMonitorRequestTemplateUpsertBulk struct { + create *ChannelMonitorRequestTemplateCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ChannelMonitorRequestTemplate.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ChannelMonitorRequestTemplateUpsertBulk) Ignore() *ChannelMonitorRequestTemplateUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ChannelMonitorRequestTemplateUpsertBulk) DoNothing() *ChannelMonitorRequestTemplateUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreateBulk.OnConflict +// documentation for more info. 
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetName(v string) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateName() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateName() + }) +} + +// SetProvider sets the "provider" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetProvider(v) + }) +} + +// UpdateProvider sets the "provider" field to the value that was provided on create. 
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateProvider() + }) +} + +// SetDescription sets the "description" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateDescription() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearDescription() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.ClearDescription() + }) +} + +// SetExtraHeaders sets the "extra_headers" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetExtraHeaders(v) + }) +} + +// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateExtraHeaders() + }) +} + +// SetBodyOverrideMode sets the "body_override_mode" field. 
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetBodyOverrideMode(v) + }) +} + +// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateBodyOverrideMode() + }) +} + +// SetBodyOverride sets the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.SetBodyOverride(v) + }) +} + +// UpdateBodyOverride sets the "body_override" field to the value that was provided on create. +func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.UpdateBodyOverride() + }) +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk { + return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) { + s.ClearBodyOverride() + }) +} + +// Exec executes the query. +func (u *ChannelMonitorRequestTemplateUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the ChannelMonitorRequestTemplateCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ChannelMonitorRequestTemplateUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitorrequesttemplate_delete.go b/backend/ent/channelmonitorrequesttemplate_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..98d365c8ed1742a88cc58177cf33f1045f6b95bb --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorRequestTemplateDelete is the builder for deleting a ChannelMonitorRequestTemplate entity. +type ChannelMonitorRequestTemplateDelete struct { + config + hooks []Hook + mutation *ChannelMonitorRequestTemplateMutation +} + +// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder. +func (_d *ChannelMonitorRequestTemplateDelete) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ChannelMonitorRequestTemplateDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *ChannelMonitorRequestTemplateDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ChannelMonitorRequestTemplateDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ChannelMonitorRequestTemplateDeleteOne is the builder for deleting a single ChannelMonitorRequestTemplate entity. +type ChannelMonitorRequestTemplateDeleteOne struct { + _d *ChannelMonitorRequestTemplateDelete +} + +// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder. +func (_d *ChannelMonitorRequestTemplateDeleteOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ChannelMonitorRequestTemplateDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{channelmonitorrequesttemplate.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *ChannelMonitorRequestTemplateDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/channelmonitorrequesttemplate_query.go b/backend/ent/channelmonitorrequesttemplate_query.go new file mode 100644 index 0000000000000000000000000000000000000000..6491ea608e16f7743bdab34bc8f5e69e877eab84 --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate_query.go @@ -0,0 +1,648 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorRequestTemplateQuery is the builder for querying ChannelMonitorRequestTemplate entities. +type ChannelMonitorRequestTemplateQuery struct { + config + ctx *QueryContext + order []channelmonitorrequesttemplate.OrderOption + inters []Interceptor + predicates []predicate.ChannelMonitorRequestTemplate + withMonitors *ChannelMonitorQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ChannelMonitorRequestTemplateQuery builder. +func (_q *ChannelMonitorRequestTemplateQuery) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ChannelMonitorRequestTemplateQuery) Limit(limit int) *ChannelMonitorRequestTemplateQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. 
+func (_q *ChannelMonitorRequestTemplateQuery) Offset(offset int) *ChannelMonitorRequestTemplateQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ChannelMonitorRequestTemplateQuery) Unique(unique bool) *ChannelMonitorRequestTemplateQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ChannelMonitorRequestTemplateQuery) Order(o ...channelmonitorrequesttemplate.OrderOption) *ChannelMonitorRequestTemplateQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryMonitors chains the current query on the "monitors" edge. +func (_q *ChannelMonitorRequestTemplateQuery) QueryMonitors() *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, selector), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first ChannelMonitorRequestTemplate entity from the query. +// Returns a *NotFoundError when no ChannelMonitorRequestTemplate was found. 
+func (_q *ChannelMonitorRequestTemplateQuery) First(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{channelmonitorrequesttemplate.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) FirstX(ctx context.Context) *ChannelMonitorRequestTemplate { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ChannelMonitorRequestTemplate ID from the query. +// Returns a *NotFoundError when no ChannelMonitorRequestTemplate ID was found. +func (_q *ChannelMonitorRequestTemplateQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{channelmonitorrequesttemplate.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ChannelMonitorRequestTemplate entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate entity is found. +// Returns a *NotFoundError when no ChannelMonitorRequestTemplate entities are found. 
+func (_q *ChannelMonitorRequestTemplateQuery) Only(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{channelmonitorrequesttemplate.Label} + default: + return nil, &NotSingularError{channelmonitorrequesttemplate.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) OnlyX(ctx context.Context) *ChannelMonitorRequestTemplate { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ChannelMonitorRequestTemplate ID in the query. +// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ChannelMonitorRequestTemplateQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{channelmonitorrequesttemplate.Label} + default: + err = &NotSingularError{channelmonitorrequesttemplate.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ChannelMonitorRequestTemplates. 
+func (_q *ChannelMonitorRequestTemplateQuery) All(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ChannelMonitorRequestTemplate, *ChannelMonitorRequestTemplateQuery]() + return withInterceptors[[]*ChannelMonitorRequestTemplate](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) AllX(ctx context.Context) []*ChannelMonitorRequestTemplate { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ChannelMonitorRequestTemplate IDs. +func (_q *ChannelMonitorRequestTemplateQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(channelmonitorrequesttemplate.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ChannelMonitorRequestTemplateQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorRequestTemplateQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. 
+func (_q *ChannelMonitorRequestTemplateQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ChannelMonitorRequestTemplateQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ChannelMonitorRequestTemplateQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ChannelMonitorRequestTemplateQuery) Clone() *ChannelMonitorRequestTemplateQuery { + if _q == nil { + return nil + } + return &ChannelMonitorRequestTemplateQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]channelmonitorrequesttemplate.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ChannelMonitorRequestTemplate{}, _q.predicates...), + withMonitors: _q.withMonitors.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithMonitors tells the query-builder to eager-load the nodes that are connected to +// the "monitors" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ChannelMonitorRequestTemplateQuery) WithMonitors(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorRequestTemplateQuery { + query := (&ChannelMonitorClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withMonitors = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. 
+// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ChannelMonitorRequestTemplate.Query(). +// GroupBy(channelmonitorrequesttemplate.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ChannelMonitorRequestTemplateQuery) GroupBy(field string, fields ...string) *ChannelMonitorRequestTemplateGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ChannelMonitorRequestTemplateGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = channelmonitorrequesttemplate.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.ChannelMonitorRequestTemplate.Query(). +// Select(channelmonitorrequesttemplate.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *ChannelMonitorRequestTemplateQuery) Select(fields ...string) *ChannelMonitorRequestTemplateSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ChannelMonitorRequestTemplateSelect{ChannelMonitorRequestTemplateQuery: _q} + sbuild.label = channelmonitorrequesttemplate.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ChannelMonitorRequestTemplateSelect configured with the given aggregations. +func (_q *ChannelMonitorRequestTemplateQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *ChannelMonitorRequestTemplateQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !channelmonitorrequesttemplate.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ChannelMonitorRequestTemplateQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorRequestTemplate, error) { + var ( + nodes = []*ChannelMonitorRequestTemplate{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withMonitors != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ChannelMonitorRequestTemplate).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ChannelMonitorRequestTemplate{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withMonitors; query != nil { + if err := _q.loadMonitors(ctx, query, nodes, + func(n *ChannelMonitorRequestTemplate) { n.Edges.Monitors = []*ChannelMonitor{} }, + func(n *ChannelMonitorRequestTemplate, e *ChannelMonitor) { + n.Edges.Monitors = append(n.Edges.Monitors, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *ChannelMonitorRequestTemplateQuery) loadMonitors(ctx 
context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorRequestTemplate, init func(*ChannelMonitorRequestTemplate), assign func(*ChannelMonitorRequestTemplate, *ChannelMonitor)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*ChannelMonitorRequestTemplate) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(channelmonitor.FieldTemplateID) + } + query.Where(predicate.ChannelMonitor(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(channelmonitorrequesttemplate.MonitorsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TemplateID + if fk == nil { + return fmt.Errorf(`foreign-key "template_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "template_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *ChannelMonitorRequestTemplateQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ChannelMonitorRequestTemplateQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + 
_spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID) + for i := range fields { + if fields[i] != channelmonitorrequesttemplate.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ChannelMonitorRequestTemplateQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(channelmonitorrequesttemplate.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = channelmonitorrequesttemplate.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. 
+func (_q *ChannelMonitorRequestTemplateQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ChannelMonitorRequestTemplateQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ChannelMonitorRequestTemplateGroupBy is the group-by builder for ChannelMonitorRequestTemplate entities. +type ChannelMonitorRequestTemplateGroupBy struct { + selector + build *ChannelMonitorRequestTemplateQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ChannelMonitorRequestTemplateGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *ChannelMonitorRequestTemplateGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ChannelMonitorRequestTemplateGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ChannelMonitorRequestTemplateSelect is the builder for selecting fields of ChannelMonitorRequestTemplate entities. +type ChannelMonitorRequestTemplateSelect struct { + *ChannelMonitorRequestTemplateQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ChannelMonitorRequestTemplateSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *ChannelMonitorRequestTemplateSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateSelect](ctx, _s.ChannelMonitorRequestTemplateQuery, _s, _s.inters, v) +} + +func (_s *ChannelMonitorRequestTemplateSelect) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/channelmonitorrequesttemplate_update.go b/backend/ent/channelmonitorrequesttemplate_update.go new file mode 100644 index 0000000000000000000000000000000000000000..8f55ba041d0c73381ad75804994c44110b3b00ad --- /dev/null +++ b/backend/ent/channelmonitorrequesttemplate_update.go @@ -0,0 +1,639 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ChannelMonitorRequestTemplateUpdate is the builder for updating ChannelMonitorRequestTemplate entities. 
+type ChannelMonitorRequestTemplateUpdate struct { + config + hooks []Hook + mutation *ChannelMonitorRequestTemplateMutation +} + +// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder. +func (_u *ChannelMonitorRequestTemplateUpdate) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetName(v string) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetProvider sets the "provider" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetProvider(v) + return _u +} + +// SetNillableProvider sets the "provider" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate { + if v != nil { + _u.SetProvider(*v) + } + return _u +} + +// SetDescription sets the "description" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. 
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdate { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ChannelMonitorRequestTemplateUpdate) ClearDescription() *ChannelMonitorRequestTemplateUpdate { + _u.mutation.ClearDescription() + return _u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetExtraHeaders(v) + return _u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetBodyOverrideMode(v) + return _u +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdate { + if v != nil { + _u.SetBodyOverrideMode(*v) + } + return _u +} + +// SetBodyOverride sets the "body_override" field. +func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.SetBodyOverride(v) + return _u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (_u *ChannelMonitorRequestTemplateUpdate) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdate { + _u.mutation.ClearBodyOverride() + return _u +} + +// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs. +func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.AddMonitorIDs(ids...) + return _u +} + +// AddMonitors adds the "monitors" edges to the ChannelMonitor entity. 
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddMonitorIDs(ids...) +} + +// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder. +func (_u *ChannelMonitorRequestTemplateUpdate) Mutation() *ChannelMonitorRequestTemplateMutation { + return _u.mutation +} + +// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity. +func (_u *ChannelMonitorRequestTemplateUpdate) ClearMonitors() *ChannelMonitorRequestTemplateUpdate { + _u.mutation.ClearMonitors() + return _u +} + +// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs. +func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate { + _u.mutation.RemoveMonitorIDs(ids...) + return _u +} + +// RemoveMonitors removes "monitors" edges to ChannelMonitor entities. +func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveMonitorIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ChannelMonitorRequestTemplateUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorRequestTemplateUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ChannelMonitorRequestTemplateUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_u *ChannelMonitorRequestTemplateUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ChannelMonitorRequestTemplateUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorRequestTemplateUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := channelmonitorrequesttemplate.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)} + } + } + if v, ok := _u.mutation.Provider(); ok { + if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)} + } + } + if v, ok := _u.mutation.Description(); ok { + if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil { + return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)} + } + } + if v, ok := _u.mutation.BodyOverrideMode(); ok { + if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, 
sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Provider(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString) + } + if value, ok := _u.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value) + } + if value, ok := _u.mutation.BodyOverrideMode(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value) + } + if value, ok := _u.mutation.BodyOverride(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value) + } + if _u.mutation.BodyOverrideCleared() { + _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON) + } + if _u.mutation.MonitorsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() { + edge 
:= &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitorrequesttemplate.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ChannelMonitorRequestTemplateUpdateOne is the builder for updating a single ChannelMonitorRequestTemplate entity. +type ChannelMonitorRequestTemplateUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ChannelMonitorRequestTemplateMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. 
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetName(v string) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetProvider sets the "provider" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetProvider(v) + return _u +} + +// SetNillableProvider sets the "provider" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne { + if v != nil { + _u.SetProvider(*v) + } + return _u +} + +// SetDescription sets the "description" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdateOne { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearDescription() *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.ClearDescription() + return _u +} + +// SetExtraHeaders sets the "extra_headers" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetExtraHeaders(v) + return _u +} + +// SetBodyOverrideMode sets the "body_override_mode" field. 
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetBodyOverrideMode(v) + return _u +} + +// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdateOne { + if v != nil { + _u.SetBodyOverrideMode(*v) + } + return _u +} + +// SetBodyOverride sets the "body_override" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.SetBodyOverride(v) + return _u +} + +// ClearBodyOverride clears the value of the "body_override" field. +func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.ClearBodyOverride() + return _u +} + +// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs. +func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.AddMonitorIDs(ids...) + return _u +} + +// AddMonitors adds the "monitors" edges to the ChannelMonitor entity. +func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddMonitorIDs(ids...) +} + +// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder. +func (_u *ChannelMonitorRequestTemplateUpdateOne) Mutation() *ChannelMonitorRequestTemplateMutation { + return _u.mutation +} + +// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity. 
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearMonitors() *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.ClearMonitors() + return _u +} + +// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs. +func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.RemoveMonitorIDs(ids...) + return _u +} + +// RemoveMonitors removes "monitors" edges to ChannelMonitor entities. +func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveMonitorIDs(ids...) +} + +// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder. +func (_u *ChannelMonitorRequestTemplateUpdateOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ChannelMonitorRequestTemplateUpdateOne) Select(field string, fields ...string) *ChannelMonitorRequestTemplateUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ChannelMonitorRequestTemplate entity. +func (_u *ChannelMonitorRequestTemplateUpdateOne) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ChannelMonitorRequestTemplateUpdateOne) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. 
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ChannelMonitorRequestTemplateUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ChannelMonitorRequestTemplateUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := channelmonitorrequesttemplate.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)} + } + } + if v, ok := _u.mutation.Provider(); ok { + if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil { + return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)} + } + } + if v, ok := _u.mutation.Description(); ok { + if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil { + return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)} + } + } + if v, ok := _u.mutation.BodyOverrideMode(); ok { + if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil { + return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)} + } + } + return nil +} + +func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (_node 
*ChannelMonitorRequestTemplate, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorRequestTemplate.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID) + for _, f := range fields { + if !channelmonitorrequesttemplate.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != channelmonitorrequesttemplate.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Provider(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString) + } + if value, ok := _u.mutation.ExtraHeaders(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value) + } + if value, ok := _u.mutation.BodyOverrideMode(); ok { 
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value) + } + if value, ok := _u.mutation.BodyOverride(); ok { + _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value) + } + if _u.mutation.BodyOverrideCleared() { + _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON) + } + if _u.mutation.MonitorsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: channelmonitorrequesttemplate.MonitorsTable, + Columns: []string{channelmonitorrequesttemplate.MonitorsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &ChannelMonitorRequestTemplate{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if 
err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{channelmonitorrequesttemplate.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015ad8b1fe576a57f6298562a5619ab3e4a8..df20ddfa34192f79b400d7b675bada312cdd2b8d 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,12 +20,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -60,18 +68,34 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. 
+ AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient + // ChannelMonitor is the client for interacting with the ChannelMonitor builders. + ChannelMonitor *ChannelMonitorClient + // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders. + ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient + // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders. + ChannelMonitorHistory *ChannelMonitorHistoryClient + // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders. + ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. 
PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -118,12 +142,20 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.AuthIdentity = NewAuthIdentityClient(c.config) + c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config) + c.ChannelMonitor = NewChannelMonitorClient(c.config) + c.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(c.config) + c.ChannelMonitorHistory = NewChannelMonitorHistoryClient(c.config) + c.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) + c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config) c.PaymentAuditLog = NewPaymentAuditLogClient(c.config) c.PaymentOrder = NewPaymentOrderClient(c.config) c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config) + c.PendingAuthSession = NewPendingAuthSessionClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -229,34 +261,42 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: 
NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ChannelMonitor: NewChannelMonitorClient(cfg), + ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg), + ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg), + ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + 
TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -274,34 +314,42 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), 
+ AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ChannelMonitor: NewChannelMonitorClient(cfg), + ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg), + ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg), + ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -332,11 +380,14 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - 
c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor, + c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, + c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -348,11 +399,14 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor, + c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory, + c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) 
@@ -372,18 +426,34 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *AuthIdentityMutation: + return c.AuthIdentity.mutate(ctx, m) + case *AuthIdentityChannelMutation: + return c.AuthIdentityChannel.mutate(ctx, m) + case *ChannelMonitorMutation: + return c.ChannelMonitor.mutate(ctx, m) + case *ChannelMonitorDailyRollupMutation: + return c.ChannelMonitorDailyRollup.mutate(ctx, m) + case *ChannelMonitorHistoryMutation: + return c.ChannelMonitorHistory.mutate(ctx, m) + case *ChannelMonitorRequestTemplateMutation: + return c.ChannelMonitorRequestTemplate.mutate(ctx, m) case *ErrorPassthroughRuleMutation: return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *IdempotencyRecordMutation: return c.IdempotencyRecord.mutate(ctx, m) + case *IdentityAdoptionDecisionMutation: + return c.IdentityAdoptionDecision.mutate(ctx, m) case *PaymentAuditLogMutation: return c.PaymentAuditLog.mutate(ctx, m) case *PaymentOrderMutation: return c.PaymentOrder.mutate(ctx, m) case *PaymentProviderInstanceMutation: return c.PaymentProviderInstance.mutate(ctx, m) + case *PendingAuthSessionMutation: + return c.PendingAuthSession.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1231,6 +1301,964 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// AuthIdentityClient is a client for the AuthIdentity schema. +type AuthIdentityClient struct { + config +} + +// NewAuthIdentityClient returns a client for the AuthIdentity from the given config. +func NewAuthIdentityClient(c config) *AuthIdentityClient { + return &AuthIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`. 
+func (c *AuthIdentityClient) Use(hooks ...Hook) { + c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`. +func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...) +} + +// Create returns a builder for creating a AuthIdentity entity. +func (c *AuthIdentityClient) Create() *AuthIdentityCreate { + mutation := newAuthIdentityMutation(c.config, OpCreate) + return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentity entities. +func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk { + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentity. +func (c *AuthIdentityClient) Update() *AuthIdentityUpdate { + mutation := newAuthIdentityMutation(c.config, OpUpdate) + return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. 
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentity. +func (c *AuthIdentityClient) Delete() *AuthIdentityDelete { + mutation := newAuthIdentityMutation(c.config, OpDelete) + return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne { + builder := c.Delete().Where(authidentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentity. +func (c *AuthIdentityClient) Query() *AuthIdentityQuery { + return &AuthIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentity entity by its id. +func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) { + return c.Query().Where(authidentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. 
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChannels queries the channels edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity. 
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityClient) Hooks() []Hook { + return c.hooks.AuthIdentity +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityClient) Interceptors() []Interceptor { + return c.inters.AuthIdentity +} + +func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op()) + } +} + +// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema. +type AuthIdentityChannelClient struct { + config +} + +// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config. +func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient { + return &AuthIdentityChannelClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. 
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`. +func (c *AuthIdentityChannelClient) Use(hooks ...Hook) { + c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`. +func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...) +} + +// Create returns a builder for creating a AuthIdentityChannel entity. +func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate { + mutation := newAuthIdentityChannelMutation(c.config, OpCreate) + return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities. +func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk { + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityChannelCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentityChannel. 
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdate) + return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete { + mutation := newAuthIdentityChannelMutation(c.config, OpDelete) + return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne { + builder := c.Delete().Where(authidentitychannel.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityChannelDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentityChannel. 
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery { + return &AuthIdentityChannelQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentityChannel}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentityChannel entity by its id. +func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) { + return c.Query().Where(authidentitychannel.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryIdentity queries the identity edge of a AuthIdentityChannel. +func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityChannelClient) Hooks() []Hook { + return c.hooks.AuthIdentityChannel +} + +// Interceptors returns the client interceptors. 
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor { + return c.inters.AuthIdentityChannel +} + +func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op()) + } +} + +// ChannelMonitorClient is a client for the ChannelMonitor schema. +type ChannelMonitorClient struct { + config +} + +// NewChannelMonitorClient returns a client for the ChannelMonitor from the given config. +func NewChannelMonitorClient(c config) *ChannelMonitorClient { + return &ChannelMonitorClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `channelmonitor.Hooks(f(g(h())))`. +func (c *ChannelMonitorClient) Use(hooks ...Hook) { + c.hooks.ChannelMonitor = append(c.hooks.ChannelMonitor, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `channelmonitor.Intercept(f(g(h())))`. +func (c *ChannelMonitorClient) Intercept(interceptors ...Interceptor) { + c.inters.ChannelMonitor = append(c.inters.ChannelMonitor, interceptors...) +} + +// Create returns a builder for creating a ChannelMonitor entity. 
+func (c *ChannelMonitorClient) Create() *ChannelMonitorCreate { + mutation := newChannelMonitorMutation(c.config, OpCreate) + return &ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ChannelMonitor entities. +func (c *ChannelMonitorClient) CreateBulk(builders ...*ChannelMonitorCreate) *ChannelMonitorCreateBulk { + return &ChannelMonitorCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ChannelMonitorClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorCreate, int)) *ChannelMonitorCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ChannelMonitorCreateBulk{err: fmt.Errorf("calling to ChannelMonitorClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ChannelMonitorCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ChannelMonitorCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ChannelMonitor. +func (c *ChannelMonitorClient) Update() *ChannelMonitorUpdate { + mutation := newChannelMonitorMutation(c.config, OpUpdate) + return &ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ChannelMonitorClient) UpdateOne(_m *ChannelMonitor) *ChannelMonitorUpdateOne { + mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitor(_m)) + return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *ChannelMonitorClient) UpdateOneID(id int64) *ChannelMonitorUpdateOne { + mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitorID(id)) + return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ChannelMonitor. +func (c *ChannelMonitorClient) Delete() *ChannelMonitorDelete { + mutation := newChannelMonitorMutation(c.config, OpDelete) + return &ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ChannelMonitorClient) DeleteOne(_m *ChannelMonitor) *ChannelMonitorDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ChannelMonitorClient) DeleteOneID(id int64) *ChannelMonitorDeleteOne { + builder := c.Delete().Where(channelmonitor.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ChannelMonitorDeleteOne{builder} +} + +// Query returns a query builder for ChannelMonitor. +func (c *ChannelMonitorClient) Query() *ChannelMonitorQuery { + return &ChannelMonitorQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeChannelMonitor}, + inters: c.Interceptors(), + } +} + +// Get returns a ChannelMonitor entity by its id. +func (c *ChannelMonitorClient) Get(ctx context.Context, id int64) (*ChannelMonitor, error) { + return c.Query().Where(channelmonitor.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ChannelMonitorClient) GetX(ctx context.Context, id int64) *ChannelMonitor { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryHistory queries the history edge of a ChannelMonitor. 
+func (c *ChannelMonitorClient) QueryHistory(_m *ChannelMonitor) *ChannelMonitorHistoryQuery { + query := (&ChannelMonitorHistoryClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id), + sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryDailyRollups queries the daily_rollups edge of a ChannelMonitor. +func (c *ChannelMonitorClient) QueryDailyRollups(_m *ChannelMonitor) *ChannelMonitorDailyRollupQuery { + query := (&ChannelMonitorDailyRollupClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id), + sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryRequestTemplate queries the request_template edge of a ChannelMonitor. 
+func (c *ChannelMonitorClient) QueryRequestTemplate(_m *ChannelMonitor) *ChannelMonitorRequestTemplateQuery { + query := (&ChannelMonitorRequestTemplateClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id), + sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *ChannelMonitorClient) Hooks() []Hook { + return c.hooks.ChannelMonitor +} + +// Interceptors returns the client interceptors. +func (c *ChannelMonitorClient) Interceptors() []Interceptor { + return c.inters.ChannelMonitor +} + +func (c *ChannelMonitorClient) mutate(ctx context.Context, m *ChannelMonitorMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ChannelMonitor mutation op: %q", m.Op()) + } +} + +// ChannelMonitorDailyRollupClient is a client for the ChannelMonitorDailyRollup schema. +type ChannelMonitorDailyRollupClient struct { + config +} + +// NewChannelMonitorDailyRollupClient returns a client for the ChannelMonitorDailyRollup from the given config. 
+func NewChannelMonitorDailyRollupClient(c config) *ChannelMonitorDailyRollupClient { + return &ChannelMonitorDailyRollupClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `channelmonitordailyrollup.Hooks(f(g(h())))`. +func (c *ChannelMonitorDailyRollupClient) Use(hooks ...Hook) { + c.hooks.ChannelMonitorDailyRollup = append(c.hooks.ChannelMonitorDailyRollup, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `channelmonitordailyrollup.Intercept(f(g(h())))`. +func (c *ChannelMonitorDailyRollupClient) Intercept(interceptors ...Interceptor) { + c.inters.ChannelMonitorDailyRollup = append(c.inters.ChannelMonitorDailyRollup, interceptors...) +} + +// Create returns a builder for creating a ChannelMonitorDailyRollup entity. +func (c *ChannelMonitorDailyRollupClient) Create() *ChannelMonitorDailyRollupCreate { + mutation := newChannelMonitorDailyRollupMutation(c.config, OpCreate) + return &ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ChannelMonitorDailyRollup entities. +func (c *ChannelMonitorDailyRollupClient) CreateBulk(builders ...*ChannelMonitorDailyRollupCreate) *ChannelMonitorDailyRollupCreateBulk { + return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *ChannelMonitorDailyRollupClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorDailyRollupCreate, int)) *ChannelMonitorDailyRollupCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ChannelMonitorDailyRollupCreateBulk{err: fmt.Errorf("calling to ChannelMonitorDailyRollupClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ChannelMonitorDailyRollupCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ChannelMonitorDailyRollup. +func (c *ChannelMonitorDailyRollupClient) Update() *ChannelMonitorDailyRollupUpdate { + mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdate) + return &ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ChannelMonitorDailyRollupClient) UpdateOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne { + mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollup(_m)) + return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ChannelMonitorDailyRollupClient) UpdateOneID(id int64) *ChannelMonitorDailyRollupUpdateOne { + mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollupID(id)) + return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ChannelMonitorDailyRollup. 
+func (c *ChannelMonitorDailyRollupClient) Delete() *ChannelMonitorDailyRollupDelete { + mutation := newChannelMonitorDailyRollupMutation(c.config, OpDelete) + return &ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ChannelMonitorDailyRollupClient) DeleteOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ChannelMonitorDailyRollupClient) DeleteOneID(id int64) *ChannelMonitorDailyRollupDeleteOne { + builder := c.Delete().Where(channelmonitordailyrollup.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ChannelMonitorDailyRollupDeleteOne{builder} +} + +// Query returns a query builder for ChannelMonitorDailyRollup. +func (c *ChannelMonitorDailyRollupClient) Query() *ChannelMonitorDailyRollupQuery { + return &ChannelMonitorDailyRollupQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeChannelMonitorDailyRollup}, + inters: c.Interceptors(), + } +} + +// Get returns a ChannelMonitorDailyRollup entity by its id. +func (c *ChannelMonitorDailyRollupClient) Get(ctx context.Context, id int64) (*ChannelMonitorDailyRollup, error) { + return c.Query().Where(channelmonitordailyrollup.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ChannelMonitorDailyRollupClient) GetX(ctx context.Context, id int64) *ChannelMonitorDailyRollup { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryMonitor queries the monitor edge of a ChannelMonitorDailyRollup. 
+func (c *ChannelMonitorDailyRollupClient) QueryMonitor(_m *ChannelMonitorDailyRollup) *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, id), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *ChannelMonitorDailyRollupClient) Hooks() []Hook { + return c.hooks.ChannelMonitorDailyRollup +} + +// Interceptors returns the client interceptors. +func (c *ChannelMonitorDailyRollupClient) Interceptors() []Interceptor { + return c.inters.ChannelMonitorDailyRollup +} + +func (c *ChannelMonitorDailyRollupClient) mutate(ctx context.Context, m *ChannelMonitorDailyRollupMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ChannelMonitorDailyRollup mutation op: %q", m.Op()) + } +} + +// ChannelMonitorHistoryClient is a client for the ChannelMonitorHistory schema. +type ChannelMonitorHistoryClient struct { + config +} + +// NewChannelMonitorHistoryClient returns a client for the ChannelMonitorHistory from the given config. 
+func NewChannelMonitorHistoryClient(c config) *ChannelMonitorHistoryClient { + return &ChannelMonitorHistoryClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `channelmonitorhistory.Hooks(f(g(h())))`. +func (c *ChannelMonitorHistoryClient) Use(hooks ...Hook) { + c.hooks.ChannelMonitorHistory = append(c.hooks.ChannelMonitorHistory, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `channelmonitorhistory.Intercept(f(g(h())))`. +func (c *ChannelMonitorHistoryClient) Intercept(interceptors ...Interceptor) { + c.inters.ChannelMonitorHistory = append(c.inters.ChannelMonitorHistory, interceptors...) +} + +// Create returns a builder for creating a ChannelMonitorHistory entity. +func (c *ChannelMonitorHistoryClient) Create() *ChannelMonitorHistoryCreate { + mutation := newChannelMonitorHistoryMutation(c.config, OpCreate) + return &ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ChannelMonitorHistory entities. +func (c *ChannelMonitorHistoryClient) CreateBulk(builders ...*ChannelMonitorHistoryCreate) *ChannelMonitorHistoryCreateBulk { + return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *ChannelMonitorHistoryClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorHistoryCreate, int)) *ChannelMonitorHistoryCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ChannelMonitorHistoryCreateBulk{err: fmt.Errorf("calling to ChannelMonitorHistoryClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ChannelMonitorHistoryCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ChannelMonitorHistory. +func (c *ChannelMonitorHistoryClient) Update() *ChannelMonitorHistoryUpdate { + mutation := newChannelMonitorHistoryMutation(c.config, OpUpdate) + return &ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ChannelMonitorHistoryClient) UpdateOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne { + mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistory(_m)) + return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ChannelMonitorHistoryClient) UpdateOneID(id int64) *ChannelMonitorHistoryUpdateOne { + mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistoryID(id)) + return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ChannelMonitorHistory. 
+func (c *ChannelMonitorHistoryClient) Delete() *ChannelMonitorHistoryDelete { + mutation := newChannelMonitorHistoryMutation(c.config, OpDelete) + return &ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ChannelMonitorHistoryClient) DeleteOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ChannelMonitorHistoryClient) DeleteOneID(id int64) *ChannelMonitorHistoryDeleteOne { + builder := c.Delete().Where(channelmonitorhistory.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ChannelMonitorHistoryDeleteOne{builder} +} + +// Query returns a query builder for ChannelMonitorHistory. +func (c *ChannelMonitorHistoryClient) Query() *ChannelMonitorHistoryQuery { + return &ChannelMonitorHistoryQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeChannelMonitorHistory}, + inters: c.Interceptors(), + } +} + +// Get returns a ChannelMonitorHistory entity by its id. +func (c *ChannelMonitorHistoryClient) Get(ctx context.Context, id int64) (*ChannelMonitorHistory, error) { + return c.Query().Where(channelmonitorhistory.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ChannelMonitorHistoryClient) GetX(ctx context.Context, id int64) *ChannelMonitorHistory { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryMonitor queries the monitor edge of a ChannelMonitorHistory. 
+func (c *ChannelMonitorHistoryClient) QueryMonitor(_m *ChannelMonitorHistory) *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, id), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *ChannelMonitorHistoryClient) Hooks() []Hook { + return c.hooks.ChannelMonitorHistory +} + +// Interceptors returns the client interceptors. +func (c *ChannelMonitorHistoryClient) Interceptors() []Interceptor { + return c.inters.ChannelMonitorHistory +} + +func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMonitorHistoryMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ChannelMonitorHistory mutation op: %q", m.Op()) + } +} + +// ChannelMonitorRequestTemplateClient is a client for the ChannelMonitorRequestTemplate schema. +type ChannelMonitorRequestTemplateClient struct { + config +} + +// NewChannelMonitorRequestTemplateClient returns a client for the ChannelMonitorRequestTemplate from the given config. 
+func NewChannelMonitorRequestTemplateClient(c config) *ChannelMonitorRequestTemplateClient { + return &ChannelMonitorRequestTemplateClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `channelmonitorrequesttemplate.Hooks(f(g(h())))`. +func (c *ChannelMonitorRequestTemplateClient) Use(hooks ...Hook) { + c.hooks.ChannelMonitorRequestTemplate = append(c.hooks.ChannelMonitorRequestTemplate, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `channelmonitorrequesttemplate.Intercept(f(g(h())))`. +func (c *ChannelMonitorRequestTemplateClient) Intercept(interceptors ...Interceptor) { + c.inters.ChannelMonitorRequestTemplate = append(c.inters.ChannelMonitorRequestTemplate, interceptors...) +} + +// Create returns a builder for creating a ChannelMonitorRequestTemplate entity. +func (c *ChannelMonitorRequestTemplateClient) Create() *ChannelMonitorRequestTemplateCreate { + mutation := newChannelMonitorRequestTemplateMutation(c.config, OpCreate) + return &ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ChannelMonitorRequestTemplate entities. +func (c *ChannelMonitorRequestTemplateClient) CreateBulk(builders ...*ChannelMonitorRequestTemplateCreate) *ChannelMonitorRequestTemplateCreateBulk { + return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *ChannelMonitorRequestTemplateClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorRequestTemplateCreate, int)) *ChannelMonitorRequestTemplateCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ChannelMonitorRequestTemplateCreateBulk{err: fmt.Errorf("calling to ChannelMonitorRequestTemplateClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ChannelMonitorRequestTemplateCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ChannelMonitorRequestTemplate. +func (c *ChannelMonitorRequestTemplateClient) Update() *ChannelMonitorRequestTemplateUpdate { + mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdate) + return &ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ChannelMonitorRequestTemplateClient) UpdateOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne { + mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplate(_m)) + return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ChannelMonitorRequestTemplateClient) UpdateOneID(id int64) *ChannelMonitorRequestTemplateUpdateOne { + mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplateID(id)) + return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ChannelMonitorRequestTemplate. 
+func (c *ChannelMonitorRequestTemplateClient) Delete() *ChannelMonitorRequestTemplateDelete { + mutation := newChannelMonitorRequestTemplateMutation(c.config, OpDelete) + return &ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ChannelMonitorRequestTemplateClient) DeleteOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ChannelMonitorRequestTemplateClient) DeleteOneID(id int64) *ChannelMonitorRequestTemplateDeleteOne { + builder := c.Delete().Where(channelmonitorrequesttemplate.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ChannelMonitorRequestTemplateDeleteOne{builder} +} + +// Query returns a query builder for ChannelMonitorRequestTemplate. +func (c *ChannelMonitorRequestTemplateClient) Query() *ChannelMonitorRequestTemplateQuery { + return &ChannelMonitorRequestTemplateQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeChannelMonitorRequestTemplate}, + inters: c.Interceptors(), + } +} + +// Get returns a ChannelMonitorRequestTemplate entity by its id. +func (c *ChannelMonitorRequestTemplateClient) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) { + return c.Query().Where(channelmonitorrequesttemplate.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ChannelMonitorRequestTemplateClient) GetX(ctx context.Context, id int64) *ChannelMonitorRequestTemplate { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryMonitors queries the monitors edge of a ChannelMonitorRequestTemplate. 
+func (c *ChannelMonitorRequestTemplateClient) QueryMonitors(_m *ChannelMonitorRequestTemplate) *ChannelMonitorQuery { + query := (&ChannelMonitorClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, id), + sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *ChannelMonitorRequestTemplateClient) Hooks() []Hook { + return c.hooks.ChannelMonitorRequestTemplate +} + +// Interceptors returns the client interceptors. +func (c *ChannelMonitorRequestTemplateClient) Interceptors() []Interceptor { + return c.inters.ChannelMonitorRequestTemplate +} + +func (c *ChannelMonitorRequestTemplateClient) mutate(ctx context.Context, m *ChannelMonitorRequestTemplateMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ChannelMonitorRequestTemplate mutation op: %q", m.Op()) + } +} + // ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. 
type ErrorPassthroughRuleClient struct { config @@ -1760,6 +2788,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco } } +// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecisionClient struct { + config +} + +// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config. +func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient { + return &IdentityAdoptionDecisionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) { + c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...) +} + +// Create returns a builder for creating a IdentityAdoptionDecision entity. +func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate) + return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities. +func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk { + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. 
For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdentityAdoptionDecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate) + return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdentityAdoptionDecision. 
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete) + return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne { + builder := c.Delete().Where(identityadoptiondecision.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdentityAdoptionDecisionDeleteOne{builder} +} + +// Query returns a query builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery { + return &IdentityAdoptionDecisionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdentityAdoptionDecision}, + inters: c.Interceptors(), + } +} + +// Get returns a IdentityAdoptionDecision entity by its id. +func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) { + return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision. 
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryIdentity queries the identity edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *IdentityAdoptionDecisionClient) Hooks() []Hook { + return c.hooks.IdentityAdoptionDecision +} + +// Interceptors returns the client interceptors. 
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor { + return c.inters.IdentityAdoptionDecision +} + +func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op()) + } +} + // PaymentAuditLogClient is a client for the PaymentAuditLog schema. type PaymentAuditLogClient struct { config @@ -2175,6 +3368,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr } } +// PendingAuthSessionClient is a client for the PendingAuthSession schema. +type PendingAuthSessionClient struct { + config +} + +// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config. +func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient { + return &PendingAuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`. +func (c *PendingAuthSessionClient) Use(hooks ...Hook) { + c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`. 
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...) +} + +// Create returns a builder for creating a PendingAuthSession entity. +func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate { + mutation := newPendingAuthSessionMutation(c.config, OpCreate) + return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities. +func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk { + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PendingAuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate { + mutation := newPendingAuthSessionMutation(c.config, OpUpdate) + return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. 
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete { + mutation := newPendingAuthSessionMutation(c.config, OpDelete) + return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne { + builder := c.Delete().Where(pendingauthsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PendingAuthSessionDeleteOne{builder} +} + +// Query returns a query builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery { + return &PendingAuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePendingAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a PendingAuthSession entity by its id. 
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) { + return c.Query().Where(pendingauthsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryTargetUser queries the target_user edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. 
+func (c *PendingAuthSessionClient) Hooks() []Hook { + return c.hooks.PendingAuthSession +} + +// Interceptors returns the client interceptors. +func (c *PendingAuthSessionClient) Interceptors() []Interceptor { + return c.inters.PendingAuthSession +} + +func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3951,6 +5309,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery { return query } +// QueryAuthIdentities queries the auth_identities edge of a User. +func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User. 
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -4628,20 +6018,24 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, - UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, - UserAttributeValue, UserSubscription []ent.Hook + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup, + ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule, + Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog, + PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, + PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, + TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } 
inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, - UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, - UserAttributeValue, UserSubscription []ent.Interceptor + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup, + ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule, + Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog, + PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, + PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, + TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 96ed5e03a9569a3b5a33922573aa9a2fbf090825..c9fcc314e9522876f4b546e2aeb9cbde9b881048 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,12 +17,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" 
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -98,32 +106,40 @@ var ( func checkColumn(t, c string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - apikey.Table: apikey.ValidColumn, - account.Table: account.ValidColumn, - accountgroup.Table: accountgroup.ValidColumn, - announcement.Table: announcement.ValidColumn, - announcementread.Table: announcementread.ValidColumn, - errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, - group.Table: group.ValidColumn, - idempotencyrecord.Table: idempotencyrecord.ValidColumn, - paymentauditlog.Table: paymentauditlog.ValidColumn, - paymentorder.Table: paymentorder.ValidColumn, - paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, - promocode.Table: promocode.ValidColumn, - promocodeusage.Table: promocodeusage.ValidColumn, - proxy.Table: proxy.ValidColumn, - redeemcode.Table: redeemcode.ValidColumn, - securitysecret.Table: securitysecret.ValidColumn, - setting.Table: setting.ValidColumn, - subscriptionplan.Table: subscriptionplan.ValidColumn, - tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, - usagecleanuptask.Table: usagecleanuptask.ValidColumn, - usagelog.Table: usagelog.ValidColumn, - user.Table: user.ValidColumn, - userallowedgroup.Table: userallowedgroup.ValidColumn, - userattributedefinition.Table: userattributedefinition.ValidColumn, - userattributevalue.Table: userattributevalue.ValidColumn, - usersubscription.Table: usersubscription.ValidColumn, + apikey.Table: apikey.ValidColumn, + account.Table: account.ValidColumn, + accountgroup.Table: accountgroup.ValidColumn, + announcement.Table: announcement.ValidColumn, + announcementread.Table: 
announcementread.ValidColumn, + authidentity.Table: authidentity.ValidColumn, + authidentitychannel.Table: authidentitychannel.ValidColumn, + channelmonitor.Table: channelmonitor.ValidColumn, + channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn, + channelmonitorhistory.Table: channelmonitorhistory.ValidColumn, + channelmonitorrequesttemplate.Table: channelmonitorrequesttemplate.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, + group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, + identityadoptiondecision.Table: identityadoptiondecision.ValidColumn, + paymentauditlog.Table: paymentauditlog.ValidColumn, + paymentorder.Table: paymentorder.ValidColumn, + paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, + pendingauthsession.Table: pendingauthsession.ValidColumn, + promocode.Table: promocode.ValidColumn, + promocodeusage.Table: promocodeusage.ValidColumn, + proxy.Table: proxy.ValidColumn, + redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, + setting.Table: setting.ValidColumn, + subscriptionplan.Table: subscriptionplan.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, + usagecleanuptask.Table: usagecleanuptask.ValidColumn, + usagelog.Table: usagelog.ValidColumn, + user.Table: user.ValidColumn, + userallowedgroup.Table: userallowedgroup.ValidColumn, + userattributedefinition.Table: userattributedefinition.ValidColumn, + userattributevalue.Table: userattributevalue.ValidColumn, + usersubscription.Table: usersubscription.ValidColumn, }) }) return columnCheck(t, c) diff --git a/backend/ent/group.go b/backend/ent/group.go index f10b50c325e5b7cc507df510b9c222737aa1bc9c..5d9ae2ed2036978d0f77935f51f6d2d2c3b9be8c 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -79,6 +79,8 @@ type Group struct { DefaultMappedModel string `json:"default_mapped_model,omitempty"` // OpenAI Messages 调度模型配置:按 
Claude 系列/精确模型映射到目标 GPT 模型 MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流 + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) @@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) } } + case group.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -599,6 +607,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("messages_dispatch_model_config=") builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + 
builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index b1371630dd3caf5ebb806f99dd62e1bd004612b6..24bd9c13d63a1d36b37f01763d85817abb253f11 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -76,6 +76,8 @@ const ( FieldDefaultMappedModel = "default_mapped_model" // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. + FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -181,6 +183,7 @@ var Columns = []string{ FieldRequirePrivacySet, FieldDefaultMappedModel, FieldMessagesDispatchModelConfig, + FieldRpmLimit, } var ( @@ -258,6 +261,8 @@ var ( DefaultMappedModelValidator func(string) error // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the Group queries. @@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. +func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. 
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index cba2ce5f0e499936a957a7c35d11f95628bfd1a1..2814d130f13e7485608f12ce9ee7f8b74e5a79ec 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. +func RpmLimit(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group { return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. +func RpmLimitGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. 
+func RpmLimitGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index f412fa4070c13875f006c492903e9f7f80ecb62e..20ea0a0fe71ddec6692ece32f2ad44e03537d9b3 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultMessagesDispatchModelConfig _c.mutation.SetMessagesDispatchModelConfig(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := group.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)} + } return nil } @@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) _node.MessagesDispatchModelConfig = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert { + u.Set(group.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert { + u.SetExcluded(group.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert { + u.Add(group.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. 
// Using this option is equivalent to using: // @@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. 
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 7b6d6193256a1baf00559edc0f2381bb1ba26a3a..cc14f897d6257ec6b3e373f0de709e9ab8033c57 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA return _u } +// SetRpmLimit sets the "rpm_limit" field. 
+func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 199dacea0ecc8c8a06ac0e0789e5e968ec63b66e..414eba242c61a23bb6174469b488a69622417e3c 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,78 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary +// function as AuthIdentity mutator. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary +// function as AuthIdentityChannel mutator. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m) +} + +// The ChannelMonitorFunc type is an adapter to allow the use of ordinary +// function as ChannelMonitor mutator. +type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ChannelMonitorFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ChannelMonitorMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorMutation", m) +} + +// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary +// function as ChannelMonitorDailyRollup mutator. +type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ChannelMonitorDailyRollupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ChannelMonitorDailyRollupMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.ChannelMonitorDailyRollupMutation", m) +} + +// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary +// function as ChannelMonitorHistory mutator. +type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ChannelMonitorHistoryFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ChannelMonitorHistoryMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorHistoryMutation", m) +} + +// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary +// function as ChannelMonitorRequestTemplate mutator. +type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ChannelMonitorRequestTemplateFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ChannelMonitorRequestTemplateMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorRequestTemplateMutation", m) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary // function as ErrorPassthroughRule mutator. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) @@ -105,6 +177,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent. return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary +// function as IdentityAdoptionDecision mutator. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary // function as PaymentAuditLog mutator. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error) @@ -141,6 +225,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary +// function as PendingAuthSession mutator. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PendingAuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go new file mode 100644 index 0000000000000000000000000000000000000000..ecaee65c2fc3b5f5cf3a6df4c3d5455a54691188 --- /dev/null +++ b/backend/ent/identityadoptiondecision.go @@ -0,0 +1,223 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecision struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // PendingAuthSessionID holds the value of the "pending_auth_session_id" field. + PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID *int64 `json:"identity_id,omitempty"` + // AdoptDisplayName holds the value of the "adopt_display_name" field. + AdoptDisplayName bool `json:"adopt_display_name,omitempty"` + // AdoptAvatar holds the value of the "adopt_avatar" field. + AdoptAvatar bool `json:"adopt_avatar,omitempty"` + // DecidedAt holds the value of the "decided_at" field. + DecidedAt time.Time `json:"decided_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set. + Edges IdentityAdoptionDecisionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph. +type IdentityAdoptionDecisionEdges struct { + // PendingAuthSession holds the value of the pending_auth_session edge. + PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"` + // Identity holds the value of the identity edge. 
+ Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) { + if e.PendingAuthSession != nil { + return e.PendingAuthSession, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: pendingauthsession.Label} + } + return nil, &NotLoadedError{edge: "pending_auth_session"} +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar: + values[i] = new(sql.NullBool) + case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID: + values[i] = new(sql.NullInt64) + case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdentityAdoptionDecision fields. 
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case identityadoptiondecision.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case identityadoptiondecision.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case identityadoptiondecision.FieldPendingAuthSessionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i]) + } else if value.Valid { + _m.PendingAuthSessionID = value.Int64 + } + case identityadoptiondecision.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = new(int64) + *_m.IdentityID = value.Int64 + } + case identityadoptiondecision.FieldAdoptDisplayName: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i]) + } else if value.Valid { + _m.AdoptDisplayName = value.Bool + } + case identityadoptiondecision.FieldAdoptAvatar: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i]) + } else if value.Valid { + _m.AdoptAvatar = value.Bool + } + case identityadoptiondecision.FieldDecidedAt: + if value, ok 
:= values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field decided_at", values[i]) + } else if value.Valid { + _m.DecidedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision. +// This includes values selected through modifiers, order, etc. +func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m) +} + +// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this IdentityAdoptionDecision. +// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne { + return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. 
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdentityAdoptionDecision is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdentityAdoptionDecision) String() string { + var builder strings.Builder + builder.WriteString("IdentityAdoptionDecision(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pending_auth_session_id=") + builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID)) + builder.WriteString(", ") + if v := _m.IdentityID; v != nil { + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("adopt_display_name=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName)) + builder.WriteString(", ") + builder.WriteString("adopt_avatar=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar)) + builder.WriteString(", ") + builder.WriteString("decided_at=") + builder.WriteString(_m.DecidedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision. +type IdentityAdoptionDecisions []*IdentityAdoptionDecision diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go new file mode 100644 index 0000000000000000000000000000000000000000..93adaf7397c4d18c07fee9601737b17f571905ef --- /dev/null +++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go @@ -0,0 +1,159 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the identityadoptiondecision type in the database. + Label = "identity_adoption_decision" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database. + FieldPendingAuthSessionID = "pending_auth_session_id" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database. + FieldAdoptDisplayName = "adopt_display_name" + // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database. + FieldAdoptAvatar = "adopt_avatar" + // FieldDecidedAt holds the string denoting the decided_at field in the database. + FieldDecidedAt = "decided_at" + // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations. + EdgePendingAuthSession = "pending_auth_session" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the identityadoptiondecision in the database. + Table = "identity_adoption_decisions" + // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge. + PendingAuthSessionTable = "identity_adoption_decisions" + // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. 
+ PendingAuthSessionInverseTable = "pending_auth_sessions" + // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge. + PendingAuthSessionColumn = "pending_auth_session_id" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "identity_adoption_decisions" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for identityadoptiondecision fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldPendingAuthSessionID, + FieldIdentityID, + FieldAdoptDisplayName, + FieldAdoptAvatar, + FieldDecidedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field. + DefaultAdoptDisplayName bool + // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field. + DefaultAdoptAvatar bool + // DefaultDecidedAt holds the default value on creation for the "decided_at" field. + DefaultDecidedAt func() time.Time +) + +// OrderOption defines the ordering options for the IdentityAdoptionDecision queries. 
+type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionID orders the results by the pending_auth_session_id field. +func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByAdoptDisplayName orders the results by the adopt_display_name field. +func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc() +} + +// ByAdoptAvatar orders the results by the adopt_avatar field. +func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc() +} + +// ByDecidedAt orders the results by the decided_at field. +func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDecidedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionField orders the results by pending_auth_session field. +func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...)) + } +} + +// ByIdentityField orders the results by identity field. 
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newPendingAuthSessionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go new file mode 100644 index 0000000000000000000000000000000000000000..1968f175063ac5962573aa4fbfcb2f0742527835 --- /dev/null +++ b/backend/ent/identityadoptiondecision/where.go @@ -0,0 +1,342 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. 
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ. +func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. 
It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ. +func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ. +func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ. +func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. 
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. 
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...)) +} + +// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. 
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field. +func IdentityIDIsNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID)) +} + +// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field. +func IdentityIDNotNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID)) +} + +// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field. +func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field. 
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v)) +} + +// DecidedAtEQ applies the EQ predicate on the "decided_at" field. +func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field. +func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v)) +} + +// DecidedAtIn applies the In predicate on the "decided_at" field. +func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...)) +} + +// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field. +func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...)) +} + +// DecidedAtGT applies the GT predicate on the "decided_at" field. +func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v)) +} + +// DecidedAtGTE applies the GTE predicate on the "decided_at" field. +func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v)) +} + +// DecidedAtLT applies the LT predicate on the "decided_at" field. +func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v)) +} + +// DecidedAtLTE applies the LTE predicate on the "decided_at" field. 
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v)) +} + +// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge. +func HasPendingAuthSession() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates). +func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newPendingAuthSessionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.NotPredicates(p)) +} diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go new file mode 100644 index 0000000000000000000000000000000000000000..491ba9f9a6626000bbc7095ba4390d03fcd00390 --- /dev/null +++ b/backend/ent/identityadoptiondecision_create.go @@ -0,0 +1,843 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionCreate struct { + config + mutation *IdentityAdoptionDecisionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. 
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetPendingAuthSessionID(v) + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetIdentityID(*v) + } + return _c +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptDisplayName(v) + return _c +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptDisplayName(*v) + } + return _c +} + +// SetAdoptAvatar sets the "adopt_avatar" field. 
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptAvatar(v) + return _c +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptAvatar(*v) + } + return _c +} + +// SetDecidedAt sets the "decided_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetDecidedAt(v) + return _c +} + +// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetDecidedAt(*v) + } + return _c +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate { + return _c.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation { + return _c.mutation +} + +// Save creates the IdentityAdoptionDecision in the database. +func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. 
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdentityAdoptionDecisionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := identityadoptiondecision.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + v := identityadoptiondecision.DefaultAdoptDisplayName + _c.mutation.SetAdoptDisplayName(v) + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + v := identityadoptiondecision.DefaultAdoptAvatar + _c.mutation.SetAdoptAvatar(v) + } + if _, ok := _c.mutation.DecidedAt(); !ok { + v := identityadoptiondecision.DefaultDecidedAt() + _c.mutation.SetDecidedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *IdentityAdoptionDecisionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)} + } + if _, ok := _c.mutation.PendingAuthSessionID(); !ok { + return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)} + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)} + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)} + } + if _, ok := _c.mutation.DecidedAt(); !ok { + return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)} + } + if len(_c.mutation.PendingAuthSessionIDs()) == 0 { + return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)} + } + return nil +} + +func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c 
*IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) { + var ( + _node = &IdentityAdoptionDecision{config: _c.config} + _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + _node.AdoptDisplayName = value + } + if value, ok := _c.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + _node.AdoptAvatar = value + } + if value, ok := _c.mutation.DecidedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value) + _node.DecidedAt = value + } + if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.PendingAuthSessionID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: 
&sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +type ( + // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing + // one IdentityAdoptionDecision node. + IdentityAdoptionDecisionUpsertOne struct { + create *IdentityAdoptionDecisionCreate + } + + // IdentityAdoptionDecisionUpsert is the "OnConflict" setter. + IdentityAdoptionDecisionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. 
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldUpdatedAt) + return u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v) + return u +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldIdentityID) + return u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetNull(identityadoptiondecision.FieldIdentityID) + return u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. 
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptDisplayName, v) + return u +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName) + return u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptAvatar, v) + return u +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := u.create.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). 
+// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. 
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. 
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk. +type IdentityAdoptionDecisionCreateBulk struct { + config + err error + builders []*IdentityAdoptionDecisionCreate + conflict []sql.ConflictOption +} + +// Save creates the IdentityAdoptionDecision entities in the database. 
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdentityAdoptionDecision, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdentityAdoptionDecisionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing +// a bulk of IdentityAdoptionDecision nodes. +type IdentityAdoptionDecisionUpsertBulk struct { + create *IdentityAdoptionDecisionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := b.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. 
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..ef3d328d050c8743b8d6080d8d7daca8a6b59d4a --- /dev/null +++ b/backend/ent/identityadoptiondecision_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDelete struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDeleteOne struct { + _d *IdentityAdoptionDecisionDelete +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{identityadoptiondecision.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go new file mode 100644 index 0000000000000000000000000000000000000000..4082d8ee74e4d0210372e7e710032975cde08ceb --- /dev/null +++ b/backend/ent/identityadoptiondecision_query.go @@ -0,0 +1,721 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionQuery struct { + config + ctx *QueryContext + order []identityadoptiondecision.OrderOption + inters []Interceptor + predicates []predicate.IdentityAdoptionDecision + withPendingAuthSession *PendingAuthSessionQuery + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder. +func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. 
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryIdentity chains the current query on the "identity" edge. 
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first IdentityAdoptionDecision entity from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision was found. +func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{identityadoptiondecision.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdentityAdoptionDecision ID from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found. 
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{identityadoptiondecision.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found. +// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found. +func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{identityadoptiondecision.Label} + default: + return nil, &NotSingularError{identityadoptiondecision.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{identityadoptiondecision.Label} + default: + err = &NotSingularError{identityadoptiondecision.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdentityAdoptionDecisions. +func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]() + return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdentityAdoptionDecision IDs. +func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. 
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery { + if _q == nil { + return nil + } + return &IdentityAdoptionDecisionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]identityadoptiondecision.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...), + withPendingAuthSession: _q.withPendingAuthSession.Clone(), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSession = query + return _q +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// GroupBy(identityadoptiondecision.FieldCreatedAt). +// Aggregate(ent.Count()). 
+// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdentityAdoptionDecisionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = identityadoptiondecision.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// Select(identityadoptiondecision.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q} + sbuild.label = identityadoptiondecision.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations. +func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !identityadoptiondecision.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) { + var ( + nodes = []*IdentityAdoptionDecision{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withPendingAuthSession != nil, + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdentityAdoptionDecision).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdentityAdoptionDecision{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withPendingAuthSession; query != nil { + if err := _q.loadPendingAuthSession(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil { + return nil, err + } + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); 
err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + fk := nodes[i].PendingAuthSessionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(pendingauthsession.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + if nodes[i].IdentityID == nil { + continue + } + fk := *nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) { + 
_spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for i := range fields { + if fields[i] != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withPendingAuthSession != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(identityadoptiondecision.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = identityadoptiondecision.Columns + } + selector := 
builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionGroupBy struct { + selector + build *IdentityAdoptionDecisionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. 
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionSelect struct { + *IdentityAdoptionDecisionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v) +} + +func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go new file mode 100644 index 0000000000000000000000000000000000000000..0ca21d270aa3dd30254c874b44e25bd7f34bd203 --- /dev/null +++ b/backend/ent/identityadoptiondecision_update.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities. 
+type IdentityAdoptionDecisionUpdate struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. 
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. 
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *IdentityAdoptionDecisionUpdate) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, 
field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. 
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. 
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne { + _u.fields = append([]string{field}, fields...) 
+ return _u +} + +// Save executes the query and returns the updated IdentityAdoptionDecision entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for _, f := range fields { + if !identityadoptiondecision.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: 
identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &IdentityAdoptionDecision{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); 
ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bbba5f6cdec4ba26ce89e5a81f532d30f9..95b68e097a3ec41dd82c83d5d1e0eacbb18a3949 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,12 +13,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -228,6 +236,168 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error) + +// Query calls f(ctx, q). 
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). 
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The ChannelMonitorFunc type is an adapter to allow the use of ordinary function as a Querier. +type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ChannelMonitorFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ChannelMonitorQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q) +} + +// The TraverseChannelMonitor type is an adapter to allow the use of ordinary function as Traverser. +type TraverseChannelMonitor func(context.Context, *ent.ChannelMonitorQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseChannelMonitor) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseChannelMonitor) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ChannelMonitorQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q) +} + +// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary function as a Querier. +type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ChannelMonitorDailyRollupFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. 
expect *ent.ChannelMonitorDailyRollupQuery", q) +} + +// The TraverseChannelMonitorDailyRollup type is an adapter to allow the use of ordinary function as Traverser. +type TraverseChannelMonitorDailyRollup func(context.Context, *ent.ChannelMonitorDailyRollupQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseChannelMonitorDailyRollup) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseChannelMonitorDailyRollup) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q) +} + +// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary function as a Querier. +type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ChannelMonitorHistoryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q) +} + +// The TraverseChannelMonitorHistory type is an adapter to allow the use of ordinary function as Traverser. +type TraverseChannelMonitorHistory func(context.Context, *ent.ChannelMonitorHistoryQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseChannelMonitorHistory) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseChannelMonitorHistory) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. 
expect *ent.ChannelMonitorHistoryQuery", q) +} + +// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary function as a Querier. +type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ChannelMonitorRequestTemplateFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q) +} + +// The TraverseChannelMonitorRequestTemplate type is an adapter to allow the use of ordinary function as Traverser. +type TraverseChannelMonitorRequestTemplate func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseChannelMonitorRequestTemplate) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseChannelMonitorRequestTemplate) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) @@ -309,6 +479,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier. 
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + +// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error) @@ -390,6 +587,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. 
expect *ent.PendingAuthSessionQuery", q) +} + +// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser. +type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -808,18 +1032,34 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.AuthIdentityQuery: + return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil + case *ent.AuthIdentityChannelQuery: + return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil + case *ent.ChannelMonitorQuery: + return &query[*ent.ChannelMonitorQuery, predicate.ChannelMonitor, channelmonitor.OrderOption]{typ: ent.TypeChannelMonitor, tq: q}, nil + case *ent.ChannelMonitorDailyRollupQuery: + return &query[*ent.ChannelMonitorDailyRollupQuery, predicate.ChannelMonitorDailyRollup, channelmonitordailyrollup.OrderOption]{typ: 
ent.TypeChannelMonitorDailyRollup, tq: q}, nil + case *ent.ChannelMonitorHistoryQuery: + return &query[*ent.ChannelMonitorHistoryQuery, predicate.ChannelMonitorHistory, channelmonitorhistory.OrderOption]{typ: ent.TypeChannelMonitorHistory, tq: q}, nil + case *ent.ChannelMonitorRequestTemplateQuery: + return &query[*ent.ChannelMonitorRequestTemplateQuery, predicate.ChannelMonitorRequestTemplate, channelmonitorrequesttemplate.OrderOption]{typ: ent.TypeChannelMonitorRequestTemplate, tq: q}, nil case *ent.ErrorPassthroughRuleQuery: return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.IdempotencyRecordQuery: return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil + case *ent.IdentityAdoptionDecisionQuery: + return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil case *ent.PaymentAuditLogQuery: return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil case *ent.PaymentOrderQuery: return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil case *ent.PaymentProviderInstanceQuery: return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil + case *ent.PendingAuthSessionQuery: + return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil case *ent.PromoCodeQuery: return 
&query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/auth_identity_fk_ondelete_test.go b/backend/ent/migrate/auth_identity_fk_ondelete_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0e37025a5520fa4cf701b2a41ed8690f2b9a9327 --- /dev/null +++ b/backend/ent/migrate/auth_identity_fk_ondelete_test.go @@ -0,0 +1,73 @@ +package migrate + +import ( + "testing" + + "entgo.io/ent/dialect/entsql" + entschema "entgo.io/ent/dialect/sql/schema" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) { + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete, + ) + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete, + ) + require.Equal( + t, + entschema.Cascade, + findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete, + ) + + require.Equal( + t, + entschema.SetNull, + findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete, + ) + require.Equal( + t, + entschema.SetNull, + findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete, + ) +} + +func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) { + idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no") + require.True(t, idx.Unique) + require.Len(t, idx.Columns, 1) + require.Equal(t, "out_trade_no", idx.Columns[0].Name) + require.NotNil(t, idx.Annotation) + require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where) +} + +func findForeignKeyBySymbol(t 
*testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey { + t.Helper() + + for _, fk := range table.ForeignKeys { + if fk.Symbol == symbol { + return fk + } + } + + require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol) + return nil +} + +func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index { + t.Helper() + + for _, idx := range table.Indexes { + if idx.Name == name { + return idx + } + } + + require.Failf(t, "missing index", "table %s should include index %s", table.Name, name) + return nil +} diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 68bdbf5546839faeaffba965d35f9ac56e616e03..178ae1708460870eb88b07afd658898d6e0538eb 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -338,6 +338,252 @@ var ( }, }, } + // AuthIdentitiesColumns holds the columns for the "auth_identities" table. + AuthIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // AuthIdentitiesTable holds the schema information for the "auth_identities" table. 
+ AuthIdentitiesTable = &schema.Table{ + Name: "auth_identities", + Columns: AuthIdentitiesColumns, + PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identities_users_auth_identities", + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentity_provider_type_provider_key_provider_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]}, + }, + { + Name: "authidentity_user_id", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + }, + { + Name: "authidentity_user_id_provider_type", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]}, + }, + }, + } + // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table. + AuthIdentityChannelsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel", Type: field.TypeString, Size: 20}, + {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "identity_id", Type: field.TypeInt64}, + } + // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table. 
+ AuthIdentityChannelsTable = &schema.Table{ + Name: "auth_identity_channels", + Columns: AuthIdentityChannelsColumns, + PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identity_channels_auth_identities_channels", + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]}, + }, + { + Name: "authidentitychannel_identity_id", + Unique: false, + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + }, + }, + } + // ChannelMonitorsColumns holds the columns for the "channel_monitors" table. + ChannelMonitorsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Size: 100}, + {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}}, + {Name: "endpoint", Type: field.TypeString, Size: 500}, + {Name: "api_key_encrypted", Type: field.TypeString}, + {Name: "primary_model", Type: field.TypeString, Size: 200}, + {Name: "extra_models", Type: field.TypeJSON}, + {Name: "group_name", Type: field.TypeString, Nullable: true, Size: 100, Default: ""}, + {Name: "enabled", Type: field.TypeBool, Default: true}, + {Name: "interval_seconds", Type: field.TypeInt}, + {Name: "last_checked_at", Type: field.TypeTime, Nullable: true}, + {Name: "created_by", Type: field.TypeInt64}, + {Name: 
"extra_headers", Type: field.TypeJSON}, + {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"}, + {Name: "body_override", Type: field.TypeJSON, Nullable: true}, + {Name: "template_id", Type: field.TypeInt64, Nullable: true}, + } + // ChannelMonitorsTable holds the schema information for the "channel_monitors" table. + ChannelMonitorsTable = &schema.Table{ + Name: "channel_monitors", + Columns: ChannelMonitorsColumns, + PrimaryKey: []*schema.Column{ChannelMonitorsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "channel_monitors_channel_monitor_request_templates_request_template", + Columns: []*schema.Column{ChannelMonitorsColumns[17]}, + RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "channelmonitor_enabled_last_checked_at", + Unique: false, + Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]}, + }, + { + Name: "channelmonitor_provider", + Unique: false, + Columns: []*schema.Column{ChannelMonitorsColumns[4]}, + }, + { + Name: "channelmonitor_group_name", + Unique: false, + Columns: []*schema.Column{ChannelMonitorsColumns[9]}, + }, + { + Name: "channelmonitor_template_id", + Unique: false, + Columns: []*schema.Column{ChannelMonitorsColumns[17]}, + }, + }, + } + // ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table. 
+ ChannelMonitorDailyRollupsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "model", Type: field.TypeString, Size: 200}, + {Name: "bucket_date", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "date"}}, + {Name: "total_checks", Type: field.TypeInt, Default: 0}, + {Name: "ok_count", Type: field.TypeInt, Default: 0}, + {Name: "operational_count", Type: field.TypeInt, Default: 0}, + {Name: "degraded_count", Type: field.TypeInt, Default: 0}, + {Name: "failed_count", Type: field.TypeInt, Default: 0}, + {Name: "error_count", Type: field.TypeInt, Default: 0}, + {Name: "sum_latency_ms", Type: field.TypeInt64, Default: 0}, + {Name: "count_latency", Type: field.TypeInt, Default: 0}, + {Name: "sum_ping_latency_ms", Type: field.TypeInt64, Default: 0}, + {Name: "count_ping_latency", Type: field.TypeInt, Default: 0}, + {Name: "computed_at", Type: field.TypeTime}, + {Name: "monitor_id", Type: field.TypeInt64}, + } + // ChannelMonitorDailyRollupsTable holds the schema information for the "channel_monitor_daily_rollups" table. 
+ ChannelMonitorDailyRollupsTable = &schema.Table{ + Name: "channel_monitor_daily_rollups", + Columns: ChannelMonitorDailyRollupsColumns, + PrimaryKey: []*schema.Column{ChannelMonitorDailyRollupsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "channel_monitor_daily_rollups_channel_monitors_daily_rollups", + Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14]}, + RefColumns: []*schema.Column{ChannelMonitorsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + Indexes: []*schema.Index{ + { + Name: "channelmonitordailyrollup_monitor_id_model_bucket_date", + Unique: true, + Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14], ChannelMonitorDailyRollupsColumns[1], ChannelMonitorDailyRollupsColumns[2]}, + }, + { + Name: "channelmonitordailyrollup_bucket_date", + Unique: false, + Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[2]}, + }, + }, + } + // ChannelMonitorHistoriesColumns holds the columns for the "channel_monitor_histories" table. + ChannelMonitorHistoriesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "model", Type: field.TypeString, Size: 200}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"operational", "degraded", "failed", "error"}}, + {Name: "latency_ms", Type: field.TypeInt, Nullable: true}, + {Name: "ping_latency_ms", Type: field.TypeInt, Nullable: true}, + {Name: "message", Type: field.TypeString, Nullable: true, Size: 500, Default: ""}, + {Name: "checked_at", Type: field.TypeTime}, + {Name: "monitor_id", Type: field.TypeInt64}, + } + // ChannelMonitorHistoriesTable holds the schema information for the "channel_monitor_histories" table. 
+ ChannelMonitorHistoriesTable = &schema.Table{ + Name: "channel_monitor_histories", + Columns: ChannelMonitorHistoriesColumns, + PrimaryKey: []*schema.Column{ChannelMonitorHistoriesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "channel_monitor_histories_channel_monitors_history", + Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]}, + RefColumns: []*schema.Column{ChannelMonitorsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + Indexes: []*schema.Index{ + { + Name: "channelmonitorhistory_monitor_id_model_checked_at", + Unique: false, + Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7], ChannelMonitorHistoriesColumns[1], ChannelMonitorHistoriesColumns[6]}, + }, + { + Name: "channelmonitorhistory_checked_at", + Unique: false, + Columns: []*schema.Column{ChannelMonitorHistoriesColumns[6]}, + }, + }, + } + // ChannelMonitorRequestTemplatesColumns holds the columns for the "channel_monitor_request_templates" table. + ChannelMonitorRequestTemplatesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Size: 100}, + {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}}, + {Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""}, + {Name: "extra_headers", Type: field.TypeJSON}, + {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"}, + {Name: "body_override", Type: field.TypeJSON, Nullable: true}, + } + // ChannelMonitorRequestTemplatesTable holds the schema information for the "channel_monitor_request_templates" table. 
+ ChannelMonitorRequestTemplatesTable = &schema.Table{ + Name: "channel_monitor_request_templates", + Columns: ChannelMonitorRequestTemplatesColumns, + PrimaryKey: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "channelmonitorrequesttemplate_provider_name", + Unique: true, + Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]}, + }, + }, + } // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. ErrorPassthroughRulesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -408,6 +654,7 @@ var ( {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -485,6 +732,49 @@ var ( }, }, } + // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table. 
+ IdentityAdoptionDecisionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "adopt_display_name", Type: field.TypeBool, Default: false}, + {Name: "adopt_avatar", Type: field.TypeBool, Default: false}, + {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "identity_id", Type: field.TypeInt64, Nullable: true}, + {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true}, + } + // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsTable = &schema.Table{ + Name: "identity_adoption_decisions", + Columns: IdentityAdoptionDecisionsColumns, + PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, + OnDelete: schema.Cascade, + }, + }, + Indexes: []*schema.Index{ + { + Name: "identityadoptiondecision_pending_auth_session_id", + Unique: true, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + }, + { + Name: "identityadoptiondecision_identity_id", + Unique: false, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + }, + }, + } // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table. 
PaymentAuditLogsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -528,6 +818,8 @@ var ( {Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "subscription_days", Type: field.TypeInt, Nullable: true}, {Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64}, + {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30}, + {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"}, {Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}}, {Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, @@ -556,7 +848,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "payment_orders_users_payment_orders", - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[39]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -564,38 +856,41 @@ var ( Indexes: []*schema.Index{ { Name: "paymentorder_out_trade_no", - Unique: false, + Unique: true, Columns: []*schema.Column{PaymentOrdersColumns[8]}, + Annotation: &entsql.IndexAnnotation{ + Where: "out_trade_no <> ''", + }, }, { Name: "paymentorder_user_id", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[39]}, }, { Name: "paymentorder_status", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[19]}, + Columns: []*schema.Column{PaymentOrdersColumns[21]}, }, { Name: "paymentorder_expires_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[27]}, + Columns: []*schema.Column{PaymentOrdersColumns[29]}, }, { Name: "paymentorder_created_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[35]}, + Columns: 
[]*schema.Column{PaymentOrdersColumns[37]}, }, { Name: "paymentorder_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[30]}, }, { Name: "paymentorder_payment_type_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]}, }, { Name: "paymentorder_order_type", @@ -638,6 +933,72 @@ var ( }, }, } + // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table. + PendingAuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "session_token", Type: field.TypeString, Size: 255}, + {Name: "intent", Type: field.TypeString, Size: 40}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: 
"completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "target_user_id", Type: field.TypeInt64, Nullable: true}, + } + // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table. 
+ PendingAuthSessionsTable = &schema.Table{ + Name: "pending_auth_sessions", + Columns: PendingAuthSessionsColumns, + PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pending_auth_sessions_users_pending_auth_sessions", + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "pendingauthsession_session_token", + Unique: true, + Columns: []*schema.Column{PendingAuthSessionsColumns[3]}, + }, + { + Name: "pendingauthsession_target_user_id", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + }, + { + Name: "pendingauthsession_expires_at", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[19]}, + }, + { + Name: "pendingauthsession_provider_type_provider_key_provider_subject", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]}, + }, + { + Name: "pendingauthsession_completion_code_hash", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[14]}, + }, + }, + } // PromoCodesColumns holds the columns for the "promo_codes" table. 
PromoCodesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1079,11 +1440,15 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "signup_source", Type: field.TypeString, Default: "email"}, + {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}}, {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // UsersTable holds the schema information for the "users" table. 
UsersTable = &schema.Table{ @@ -1318,12 +1683,20 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + AuthIdentitiesTable, + AuthIdentityChannelsTable, + ChannelMonitorsTable, + ChannelMonitorDailyRollupsTable, + ChannelMonitorHistoriesTable, + ChannelMonitorRequestTemplatesTable, ErrorPassthroughRulesTable, GroupsTable, IdempotencyRecordsTable, + IdentityAdoptionDecisionsTable, PaymentAuditLogsTable, PaymentOrdersTable, PaymentProviderInstancesTable, + PendingAuthSessionsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1365,6 +1738,29 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable + AuthIdentitiesTable.Annotation = &entsql.Annotation{ + Table: "auth_identities", + } + AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + AuthIdentityChannelsTable.Annotation = &entsql.Annotation{ + Table: "auth_identity_channels", + } + ChannelMonitorsTable.ForeignKeys[0].RefTable = ChannelMonitorRequestTemplatesTable + ChannelMonitorsTable.Annotation = &entsql.Annotation{ + Table: "channel_monitors", + } + ChannelMonitorDailyRollupsTable.ForeignKeys[0].RefTable = ChannelMonitorsTable + ChannelMonitorDailyRollupsTable.Annotation = &entsql.Annotation{ + Table: "channel_monitor_daily_rollups", + } + ChannelMonitorHistoriesTable.ForeignKeys[0].RefTable = ChannelMonitorsTable + ChannelMonitorHistoriesTable.Annotation = &entsql.Annotation{ + Table: "channel_monitor_histories", + } + ChannelMonitorRequestTemplatesTable.Annotation = &entsql.Annotation{ + Table: "channel_monitor_request_templates", + } ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ Table: "error_passthrough_rules", } @@ -1374,6 +1770,11 @@ func init() { IdempotencyRecordsTable.Annotation = &entsql.Annotation{ Table: "idempotency_records", } + IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + 
IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable + IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{ + Table: "identity_adoption_decisions", + } PaymentAuditLogsTable.Annotation = &entsql.Annotation{ Table: "payment_audit_logs", } @@ -1384,6 +1785,10 @@ func init() { PaymentProviderInstancesTable.Annotation = &entsql.Annotation{ Table: "payment_provider_instances", } + PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable + PendingAuthSessionsTable.Annotation = &entsql.Annotation{ + Table: "pending_auth_sessions", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 524ccb925f4561a123bdcd7f8449eac02a72c64b..d616e4ae12f16d27f7fc90c63991dc5ed44c09cf 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,12 +17,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -51,32 +59,40 @@ const ( OpUpdateOne = ent.OpUpdateOne // 
Node types. - TypeAPIKey = "APIKey" - TypeAccount = "Account" - TypeAccountGroup = "AccountGroup" - TypeAnnouncement = "Announcement" - TypeAnnouncementRead = "AnnouncementRead" - TypeErrorPassthroughRule = "ErrorPassthroughRule" - TypeGroup = "Group" - TypeIdempotencyRecord = "IdempotencyRecord" - TypePaymentAuditLog = "PaymentAuditLog" - TypePaymentOrder = "PaymentOrder" - TypePaymentProviderInstance = "PaymentProviderInstance" - TypePromoCode = "PromoCode" - TypePromoCodeUsage = "PromoCodeUsage" - TypeProxy = "Proxy" - TypeRedeemCode = "RedeemCode" - TypeSecuritySecret = "SecuritySecret" - TypeSetting = "Setting" - TypeSubscriptionPlan = "SubscriptionPlan" - TypeTLSFingerprintProfile = "TLSFingerprintProfile" - TypeUsageCleanupTask = "UsageCleanupTask" - TypeUsageLog = "UsageLog" - TypeUser = "User" - TypeUserAllowedGroup = "UserAllowedGroup" - TypeUserAttributeDefinition = "UserAttributeDefinition" - TypeUserAttributeValue = "UserAttributeValue" - TypeUserSubscription = "UserSubscription" + TypeAPIKey = "APIKey" + TypeAccount = "Account" + TypeAccountGroup = "AccountGroup" + TypeAnnouncement = "Announcement" + TypeAnnouncementRead = "AnnouncementRead" + TypeAuthIdentity = "AuthIdentity" + TypeAuthIdentityChannel = "AuthIdentityChannel" + TypeChannelMonitor = "ChannelMonitor" + TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup" + TypeChannelMonitorHistory = "ChannelMonitorHistory" + TypeChannelMonitorRequestTemplate = "ChannelMonitorRequestTemplate" + TypeErrorPassthroughRule = "ErrorPassthroughRule" + TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" + TypeIdentityAdoptionDecision = "IdentityAdoptionDecision" + TypePaymentAuditLog = "PaymentAuditLog" + TypePaymentOrder = "PaymentOrder" + TypePaymentProviderInstance = "PaymentProviderInstance" + TypePendingAuthSession = "PendingAuthSession" + TypePromoCode = "PromoCode" + TypePromoCodeUsage = "PromoCodeUsage" + TypeProxy = "Proxy" + TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = 
"SecuritySecret" + TypeSetting = "Setting" + TypeSubscriptionPlan = "SubscriptionPlan" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" + TypeUsageCleanupTask = "UsageCleanupTask" + TypeUsageLog = "UsageLog" + TypeUser = "User" + TypeUserAllowedGroup = "UserAllowedGroup" + TypeUserAttributeDefinition = "UserAttributeDefinition" + TypeUserAttributeValue = "UserAttributeValue" + TypeUserSubscription = "UserSubscription" ) // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. @@ -6887,49 +6903,45 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } -// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. -type ErrorPassthroughRuleMutation struct { +// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph. +type AuthIdentityMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - name *string - enabled *bool - priority *int - addpriority *int - error_codes *[]int - appenderror_codes []int - keywords *[]string - appendkeywords []string - match_mode *string - platforms *[]string - appendplatforms []string - passthrough_code *bool - response_code *int - addresponse_code *int - passthrough_body *bool - custom_message *string - skip_monitoring *bool - description *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*ErrorPassthroughRule, error) - predicates []predicate.ErrorPassthroughRule + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + provider_subject *string + verified_at *time.Time + issuer *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + user *int64 + cleareduser bool + channels map[int64]struct{} + removedchannels map[int64]struct{} + clearedchannels bool + 
adoption_decisions map[int64]struct{} + removedadoption_decisions map[int64]struct{} + clearedadoption_decisions bool + done bool + oldValue func(context.Context) (*AuthIdentity, error) + predicates []predicate.AuthIdentity } -var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) +var _ ent.Mutation = (*AuthIdentityMutation)(nil) -// errorpassthroughruleOption allows management of the mutation configuration using functional options. -type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) +// authidentityOption allows management of the mutation configuration using functional options. +type authidentityOption func(*AuthIdentityMutation) -// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. -func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { - m := &ErrorPassthroughRuleMutation{ +// newAuthIdentityMutation creates new mutation for the AuthIdentity entity. +func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation { + m := &AuthIdentityMutation{ config: c, op: op, - typ: TypeErrorPassthroughRule, + typ: TypeAuthIdentity, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -6938,20 +6950,20 @@ func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughru return m } -// withErrorPassthroughRuleID sets the ID field of the mutation. -func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { +// withAuthIdentityID sets the ID field of the mutation. 
+func withAuthIdentityID(id int64) authidentityOption { + return func(m *AuthIdentityMutation) { var ( err error once sync.Once - value *ErrorPassthroughRule + value *AuthIdentity ) - m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + m.oldValue = func(ctx context.Context) (*AuthIdentity, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + value, err = m.Client().AuthIdentity.Get(ctx, id) } }) return value, err @@ -6960,10 +6972,10 @@ func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { } } -// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. -func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { - m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { +// withAuthIdentity sets the old AuthIdentity of the mutation. +func withAuthIdentity(node *AuthIdentity) authidentityOption { + return func(m *AuthIdentityMutation) { + m.oldValue = func(context.Context) (*AuthIdentity, error) { return node, nil } m.id = &node.ID @@ -6972,7 +6984,7 @@ func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOp // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m ErrorPassthroughRuleMutation) Client() *Client { +func (m AuthIdentityMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -6980,7 +6992,7 @@ func (m ErrorPassthroughRuleMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. 
-func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { +func (m AuthIdentityMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -6991,7 +7003,7 @@ func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { +func (m *AuthIdentityMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -7002,7 +7014,7 @@ func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -7011,19 +7023,19 @@ func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. 
-func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -7031,10 +7043,10 @@ func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -7049,17 +7061,17 @@ func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { +func (m *AuthIdentityMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. 
-func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -7067,10 +7079,10 @@ func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -7085,724 +7097,484 @@ func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { +func (m *AuthIdentityMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetName sets the "name" field. -func (m *ErrorPassthroughRuleMutation) SetName(s string) { - m.name = &s +// SetUserID sets the "user_id" field. +func (m *AuthIdentityMutation) SetUserID(i int64) { + m.user = &i } -// Name returns the value of the "name" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { - v := m.name +// UserID returns the value of the "user_id" field in the mutation. 
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUserID returns the old "user_id" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.Name, nil + return oldValue.UserID, nil } -// ResetName resets all changes to the "name" field. -func (m *ErrorPassthroughRuleMutation) ResetName() { - m.name = nil +// ResetUserID resets all changes to the "user_id" field. +func (m *AuthIdentityMutation) ResetUserID() { + m.user = nil } -// SetEnabled sets the "enabled" field. -func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { - m.enabled = &b +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityMutation) SetProviderType(s string) { + m.provider_type = &s } -// Enabled returns the value of the "enabled" field in the mutation. 
-func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) { + v := m.provider_type if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldProviderType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) } - return oldValue.Enabled, nil + return oldValue.ProviderType, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *ErrorPassthroughRuleMutation) ResetEnabled() { - m.enabled = nil +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityMutation) ResetProviderType() { + m.provider_type = nil } -// SetPriority sets the "priority" field. 
-func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { - m.priority = &i - m.addpriority = nil +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityMutation) SetProviderKey(s string) { + m.provider_key = &s } -// Priority returns the value of the "priority" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { - v := m.priority +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key if v == nil { return } return *v, true } -// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { +func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPriority is only allowed on UpdateOne operations") + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPriority requires an ID field in the mutation") + return v, errors.New("OldProviderKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPriority: %w", err) - } - return oldValue.Priority, nil -} - -// AddPriority adds i to the "priority" field. 
-func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { - if m.addpriority != nil { - *m.addpriority += i - } else { - m.addpriority = &i - } -} - -// AddedPriority returns the value that was added to the "priority" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { - v := m.addpriority - if v == nil { - return + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) } - return *v, true + return oldValue.ProviderKey, nil } -// ResetPriority resets all changes to the "priority" field. -func (m *ErrorPassthroughRuleMutation) ResetPriority() { - m.priority = nil - m.addpriority = nil +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityMutation) ResetProviderKey() { + m.provider_key = nil } -// SetErrorCodes sets the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { - m.error_codes = &i - m.appenderror_codes = nil +// SetProviderSubject sets the "provider_subject" field. +func (m *AuthIdentityMutation) SetProviderSubject(s string) { + m.provider_subject = &s } -// ErrorCodes returns the value of the "error_codes" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { - v := m.error_codes +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject if v == nil { return } return *v, true } -// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { +func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorCodes requires an ID field in the mutation") + return v, errors.New("OldProviderSubject requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) - } - return oldValue.ErrorCodes, nil -} - -// AppendErrorCodes adds i to the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { - m.appenderror_codes = append(m.appenderror_codes, i...) -} - -// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { - if len(m.appenderror_codes) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) } - return m.appenderror_codes, true -} - -// ClearErrorCodes clears the value of the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} + return oldValue.ProviderSubject, nil } -// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] - return ok -} - -// ResetErrorCodes resets all changes to the "error_codes" field. 
-func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *AuthIdentityMutation) ResetProviderSubject() { + m.provider_subject = nil } -// SetKeywords sets the "keywords" field. -func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { - m.keywords = &s - m.appendkeywords = nil +// SetVerifiedAt sets the "verified_at" field. +func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) { + m.verified_at = &t } -// Keywords returns the value of the "keywords" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { - v := m.keywords +// VerifiedAt returns the value of the "verified_at" field in the mutation. +func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) { + v := m.verified_at if v == nil { return } return *v, true } -// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { +func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldKeywords requires an ID field in the mutation") + return v, errors.New("OldVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldKeywords: %w", err) - } - return oldValue.Keywords, nil -} - -// AppendKeywords adds s to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { - m.appendkeywords = append(m.appendkeywords, s...) -} - -// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { - if len(m.appendkeywords) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err) } - return m.appendkeywords, true + return oldValue.VerifiedAt, nil } -// ClearKeywords clears the value of the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ClearKeywords() { - m.keywords = nil - m.appendkeywords = nil - m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +// ClearVerifiedAt clears the value of the "verified_at" field. +func (m *AuthIdentityMutation) ClearVerifiedAt() { + m.verified_at = nil + m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{} } -// KeywordsCleared returns if the "keywords" field was cleared in this mutation. 
-func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] +// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation. +func (m *AuthIdentityMutation) VerifiedAtCleared() bool { + _, ok := m.clearedFields[authidentity.FieldVerifiedAt] return ok } -// ResetKeywords resets all changes to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ResetKeywords() { - m.keywords = nil - m.appendkeywords = nil - delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +// ResetVerifiedAt resets all changes to the "verified_at" field. +func (m *AuthIdentityMutation) ResetVerifiedAt() { + m.verified_at = nil + delete(m.clearedFields, authidentity.FieldVerifiedAt) } -// SetMatchMode sets the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { - m.match_mode = &s +// SetIssuer sets the "issuer" field. +func (m *AuthIdentityMutation) SetIssuer(s string) { + m.issuer = &s } -// MatchMode returns the value of the "match_mode" field in the mutation. -func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { - v := m.match_mode +// Issuer returns the value of the "issuer" field in the mutation. +func (m *AuthIdentityMutation) Issuer() (r string, exists bool) { + v := m.issuer if v == nil { return } return *v, true } -// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + return v, errors.New("OldIssuer is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMatchMode requires an ID field in the mutation") + return v, errors.New("OldIssuer requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + return v, fmt.Errorf("querying old value for OldIssuer: %w", err) } - return oldValue.MatchMode, nil + return oldValue.Issuer, nil } -// ResetMatchMode resets all changes to the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { - m.match_mode = nil +// ClearIssuer clears the value of the "issuer" field. +func (m *AuthIdentityMutation) ClearIssuer() { + m.issuer = nil + m.clearedFields[authidentity.FieldIssuer] = struct{}{} } -// SetPlatforms sets the "platforms" field. -func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { - m.platforms = &s - m.appendplatforms = nil +// IssuerCleared returns if the "issuer" field was cleared in this mutation. +func (m *AuthIdentityMutation) IssuerCleared() bool { + _, ok := m.clearedFields[authidentity.FieldIssuer] + return ok } -// Platforms returns the value of the "platforms" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { - v := m.platforms +// ResetIssuer resets all changes to the "issuer" field. +func (m *AuthIdentityMutation) ResetIssuer() { + m.issuer = nil + delete(m.clearedFields, authidentity.FieldIssuer) +} + +// SetMetadata sets the "metadata" field. 
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata if v == nil { return } return *v, true } -// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { +func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatforms requires an ID field in the mutation") + return v, errors.New("OldMetadata requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) - } - return oldValue.Platforms, nil -} - -// AppendPlatforms adds s to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { - m.appendplatforms = append(m.appendplatforms, s...) -} - -// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. 
-func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { - if len(m.appendplatforms) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) } - return m.appendplatforms, true + return oldValue.Metadata, nil } -// ClearPlatforms clears the value of the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { - m.platforms = nil - m.appendplatforms = nil - m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityMutation) ResetMetadata() { + m.metadata = nil } -// PlatformsCleared returns if the "platforms" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] - return ok +// ClearUser clears the "user" edge to the User entity. +func (m *AuthIdentityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[authidentity.FieldUserID] = struct{}{} } -// ResetPlatforms resets all changes to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { - m.platforms = nil - m.appendplatforms = nil - delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *AuthIdentityMutation) UserCleared() bool { + return m.cleareduser } -// SetPassthroughCode sets the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { - m.passthrough_code = &b +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. 
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return } -// PassthroughCode returns the value of the "passthrough_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { - v := m.passthrough_code - if v == nil { - return - } - return *v, true +// ResetUser resets all changes to the "user" edge. +func (m *AuthIdentityMutation) ResetUser() { + m.user = nil + m.cleareduser = false } -// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughCode requires an ID field in the mutation") +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids. +func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) { + if m.channels == nil { + m.channels = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + for i := range ids { + m.channels[ids[i]] = struct{}{} } - return oldValue.PassthroughCode, nil -} - -// ResetPassthroughCode resets all changes to the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { - m.passthrough_code = nil } -// SetResponseCode sets the "response_code" field. 
-func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { - m.response_code = &i - m.addresponse_code = nil +// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) ClearChannels() { + m.clearedchannels = true } -// ResponseCode returns the value of the "response_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { - v := m.response_code - if v == nil { - return - } - return *v, true +// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared. +func (m *AuthIdentityMutation) ChannelsCleared() bool { + return m.clearedchannels } -// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseCode requires an ID field in the mutation") +// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs. +func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) { + if m.removedchannels == nil { + m.removedchannels = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) + for i := range ids { + delete(m.channels, ids[i]) + m.removedchannels[ids[i]] = struct{}{} } - return oldValue.ResponseCode, nil } -// AddResponseCode adds i to the "response_code" field. 
-func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { - if m.addresponse_code != nil { - *m.addresponse_code += i - } else { - m.addresponse_code = &i +// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) { + for id := range m.removedchannels { + ids = append(ids, id) } + return } -// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { - v := m.addresponse_code - if v == nil { - return +// ChannelsIDs returns the "channels" edge IDs in the mutation. +func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) { + for id := range m.channels { + ids = append(ids, id) } - return *v, true + return } -// ClearResponseCode clears the value of the "response_code" field. -func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { - m.response_code = nil - m.addresponse_code = nil - m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +// ResetChannels resets all changes to the "channels" edge. +func (m *AuthIdentityMutation) ResetChannels() { + m.channels = nil + m.clearedchannels = false + m.removedchannels = nil } -// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] - return ok +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids. +func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) { + if m.adoption_decisions == nil { + m.adoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + m.adoption_decisions[ids[i]] = struct{}{} + } } -// ResetResponseCode resets all changes to the "response_code" field. 
-func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { - m.response_code = nil - m.addresponse_code = nil - delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) ClearAdoptionDecisions() { + m.clearedadoption_decisions = true } -// SetPassthroughBody sets the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { - m.passthrough_body = &b +// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared. +func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool { + return m.clearedadoption_decisions } -// PassthroughBody returns the value of the "passthrough_body" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { - v := m.passthrough_body - if v == nil { - return +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) { + if m.removedadoption_decisions == nil { + m.removedadoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.adoption_decisions, ids[i]) + m.removedadoption_decisions[ids[i]] = struct{}{} } - return *v, true } -// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughBody requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) +// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) { + for id := range m.removedadoption_decisions { + ids = append(ids, id) } - return oldValue.PassthroughBody, nil + return } -// ResetPassthroughBody resets all changes to the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { - m.passthrough_body = nil +// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation. +func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) { + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return } -// SetCustomMessage sets the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { - m.custom_message = &s +// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge. +func (m *AuthIdentityMutation) ResetAdoptionDecisions() { + m.adoption_decisions = nil + m.clearedadoption_decisions = false + m.removedadoption_decisions = nil } -// CustomMessage returns the value of the "custom_message" field in the mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { - v := m.custom_message - if v == nil { - return - } - return *v, true +// Where appends a list predicates to the AuthIdentityMutation builder. 
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) { + m.predicates = append(m.predicates, ps...) } -// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCustomMessage requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) +// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] } - return oldValue.CustomMessage, nil -} - -// ClearCustomMessage clears the value of the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { - m.custom_message = nil - m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} + m.Where(p...) } -// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] - return ok +// Op returns the operation name. +func (m *AuthIdentityMutation) Op() Op { + return m.op } -// ResetCustomMessage resets all changes to the "custom_message" field. 
-func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { - m.custom_message = nil - delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +// SetOp allows setting the mutation operation. +func (m *AuthIdentityMutation) SetOp(op Op) { + m.op = op } -// SetSkipMonitoring sets the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { - m.skip_monitoring = &b +// Type returns the node type of this mutation (AuthIdentity). +func (m *AuthIdentityMutation) Type() string { + return m.typ } -// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. -func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { - v := m.skip_monitoring - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentity.FieldCreatedAt) } - return *v, true -} - -// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") + if m.updated_at != nil { + fields = append(fields, authidentity.FieldUpdatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) - } - return oldValue.SkipMonitoring, nil -} - -// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { - m.skip_monitoring = nil -} - -// SetDescription sets the "description" field. -func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { - m.description = &s -} - -// Description returns the value of the "description" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { - v := m.description - if v == nil { - return - } - return *v, true -} - -// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) - } - return oldValue.Description, nil -} - -// ClearDescription clears the value of the "description" field. -func (m *ErrorPassthroughRuleMutation) ClearDescription() { - m.description = nil - m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} -} - -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] - return ok -} - -// ResetDescription resets all changes to the "description" field. -func (m *ErrorPassthroughRuleMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, errorpassthroughrule.FieldDescription) -} - -// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. -func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { - m.predicates = append(m.predicates, ps...) -} - -// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.ErrorPassthroughRule, len(ps)) - for i := range ps { - p[i] = ps[i] - } - m.Where(p...) -} - -// Op returns the operation name. -func (m *ErrorPassthroughRuleMutation) Op() Op { - return m.op -} - -// SetOp allows setting the mutation operation. 
-func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { - m.op = op -} - -// Type returns the node type of this mutation (ErrorPassthroughRule). -func (m *ErrorPassthroughRuleMutation) Type() string { - return m.typ -} - -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *ErrorPassthroughRuleMutation) Fields() []string { - fields := make([]string, 0, 15) - if m.created_at != nil { - fields = append(fields, errorpassthroughrule.FieldCreatedAt) - } - if m.updated_at != nil { - fields = append(fields, errorpassthroughrule.FieldUpdatedAt) - } - if m.name != nil { - fields = append(fields, errorpassthroughrule.FieldName) - } - if m.enabled != nil { - fields = append(fields, errorpassthroughrule.FieldEnabled) - } - if m.priority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) - } - if m.error_codes != nil { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) - } - if m.keywords != nil { - fields = append(fields, errorpassthroughrule.FieldKeywords) - } - if m.match_mode != nil { - fields = append(fields, errorpassthroughrule.FieldMatchMode) - } - if m.platforms != nil { - fields = append(fields, errorpassthroughrule.FieldPlatforms) + if m.user != nil { + fields = append(fields, authidentity.FieldUserID) } - if m.passthrough_code != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + if m.provider_type != nil { + fields = append(fields, authidentity.FieldProviderType) } - if m.response_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) + if m.provider_key != nil { + fields = append(fields, authidentity.FieldProviderKey) } - if m.passthrough_body != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + if m.provider_subject != nil { + fields = append(fields, authidentity.FieldProviderSubject) } - if m.custom_message != nil { - 
fields = append(fields, errorpassthroughrule.FieldCustomMessage) + if m.verified_at != nil { + fields = append(fields, authidentity.FieldVerifiedAt) } - if m.skip_monitoring != nil { - fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + if m.issuer != nil { + fields = append(fields, authidentity.FieldIssuer) } - if m.description != nil { - fields = append(fields, errorpassthroughrule.FieldDescription) + if m.metadata != nil { + fields = append(fields, authidentity.FieldMetadata) } return fields } @@ -7810,38 +7582,26 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { +func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentity.FieldCreatedAt: return m.CreatedAt() - case errorpassthroughrule.FieldUpdatedAt: + case authidentity.FieldUpdatedAt: return m.UpdatedAt() - case errorpassthroughrule.FieldName: - return m.Name() - case errorpassthroughrule.FieldEnabled: - return m.Enabled() - case errorpassthroughrule.FieldPriority: - return m.Priority() - case errorpassthroughrule.FieldErrorCodes: - return m.ErrorCodes() - case errorpassthroughrule.FieldKeywords: - return m.Keywords() - case errorpassthroughrule.FieldMatchMode: - return m.MatchMode() - case errorpassthroughrule.FieldPlatforms: - return m.Platforms() - case errorpassthroughrule.FieldPassthroughCode: - return m.PassthroughCode() - case errorpassthroughrule.FieldResponseCode: - return m.ResponseCode() - case errorpassthroughrule.FieldPassthroughBody: - return m.PassthroughBody() - case errorpassthroughrule.FieldCustomMessage: - return m.CustomMessage() - case errorpassthroughrule.FieldSkipMonitoring: - return m.SkipMonitoring() - case 
errorpassthroughrule.FieldDescription: - return m.Description() + case authidentity.FieldUserID: + return m.UserID() + case authidentity.FieldProviderType: + return m.ProviderType() + case authidentity.FieldProviderKey: + return m.ProviderKey() + case authidentity.FieldProviderSubject: + return m.ProviderSubject() + case authidentity.FieldVerifiedAt: + return m.VerifiedAt() + case authidentity.FieldIssuer: + return m.Issuer() + case authidentity.FieldMetadata: + return m.Metadata() } return nil, false } @@ -7849,178 +7609,114 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentity.FieldCreatedAt: return m.OldCreatedAt(ctx) - case errorpassthroughrule.FieldUpdatedAt: + case authidentity.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case errorpassthroughrule.FieldName: - return m.OldName(ctx) - case errorpassthroughrule.FieldEnabled: - return m.OldEnabled(ctx) - case errorpassthroughrule.FieldPriority: - return m.OldPriority(ctx) - case errorpassthroughrule.FieldErrorCodes: - return m.OldErrorCodes(ctx) - case errorpassthroughrule.FieldKeywords: - return m.OldKeywords(ctx) - case errorpassthroughrule.FieldMatchMode: - return m.OldMatchMode(ctx) - case errorpassthroughrule.FieldPlatforms: - return m.OldPlatforms(ctx) - case errorpassthroughrule.FieldPassthroughCode: - return m.OldPassthroughCode(ctx) - case errorpassthroughrule.FieldResponseCode: - return m.OldResponseCode(ctx) - case errorpassthroughrule.FieldPassthroughBody: - return m.OldPassthroughBody(ctx) - case 
errorpassthroughrule.FieldCustomMessage: - return m.OldCustomMessage(ctx) - case errorpassthroughrule.FieldSkipMonitoring: - return m.OldSkipMonitoring(ctx) - case errorpassthroughrule.FieldDescription: - return m.OldDescription(ctx) + case authidentity.FieldUserID: + return m.OldUserID(ctx) + case authidentity.FieldProviderType: + return m.OldProviderType(ctx) + case authidentity.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentity.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case authidentity.FieldVerifiedAt: + return m.OldVerifiedAt(ctx) + case authidentity.FieldIssuer: + return m.OldIssuer(ctx) + case authidentity.FieldMetadata: + return m.OldMetadata(ctx) } - return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return nil, fmt.Errorf("unknown AuthIdentity field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { +func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentity.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case errorpassthroughrule.FieldUpdatedAt: + case authidentity.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case errorpassthroughrule.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case errorpassthroughrule.FieldEnabled: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetEnabled(v) - return nil - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPriority(v) - return nil - case errorpassthroughrule.FieldErrorCodes: - v, ok := value.([]int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetErrorCodes(v) - return nil - case errorpassthroughrule.FieldKeywords: - v, ok := value.([]string) + case authidentity.FieldUserID: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetKeywords(v) + m.SetUserID(v) return nil - case errorpassthroughrule.FieldMatchMode: + case authidentity.FieldProviderType: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetMatchMode(v) - return nil - case errorpassthroughrule.FieldPlatforms: - v, ok := value.([]string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlatforms(v) + m.SetProviderType(v) 
return nil - case errorpassthroughrule.FieldPassthroughCode: - v, ok := value.(bool) + case authidentity.FieldProviderKey: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughCode(v) + m.SetProviderKey(v) return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) + case authidentity.FieldProviderSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseCode(v) + m.SetProviderSubject(v) return nil - case errorpassthroughrule.FieldPassthroughBody: - v, ok := value.(bool) + case authidentity.FieldVerifiedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughBody(v) + m.SetVerifiedAt(v) return nil - case errorpassthroughrule.FieldCustomMessage: + case authidentity.FieldIssuer: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCustomMessage(v) - return nil - case errorpassthroughrule.FieldSkipMonitoring: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSkipMonitoring(v) + m.SetIssuer(v) return nil - case errorpassthroughrule.FieldDescription: - v, ok := value.(string) + case authidentity.FieldMetadata: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetDescription(v) + m.SetMetadata(v) return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentity field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. 
-func (m *ErrorPassthroughRuleMutation) AddedFields() []string { +func (m *AuthIdentityMutation) AddedFields() []string { var fields []string - if m.addpriority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) - } - if m.addresponse_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { +func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) { switch name { - case errorpassthroughrule.FieldPriority: - return m.AddedPriority() - case errorpassthroughrule.FieldResponseCode: - return m.AddedResponseCode() } return nil, false } @@ -8028,268 +7724,242 @@ func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { +func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error { switch name { - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPriority(v) - return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddResponseCode(v) - return nil } - return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) + return fmt.Errorf("unknown AuthIdentity numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. 
-func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { +func (m *AuthIdentityMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) - } - if m.FieldCleared(errorpassthroughrule.FieldKeywords) { - fields = append(fields, errorpassthroughrule.FieldKeywords) - } - if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { - fields = append(fields, errorpassthroughrule.FieldPlatforms) - } - if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { - fields = append(fields, errorpassthroughrule.FieldResponseCode) + if m.FieldCleared(authidentity.FieldVerifiedAt) { + fields = append(fields, authidentity.FieldVerifiedAt) } - if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { - fields = append(fields, errorpassthroughrule.FieldCustomMessage) - } - if m.FieldCleared(errorpassthroughrule.FieldDescription) { - fields = append(fields, errorpassthroughrule.FieldDescription) + if m.FieldCleared(authidentity.FieldIssuer) { + fields = append(fields, authidentity.FieldIssuer) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { +func (m *AuthIdentityMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. 
-func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { +func (m *AuthIdentityMutation) ClearField(name string) error { switch name { - case errorpassthroughrule.FieldErrorCodes: - m.ClearErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ClearKeywords() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ClearPlatforms() - return nil - case errorpassthroughrule.FieldResponseCode: - m.ClearResponseCode() - return nil - case errorpassthroughrule.FieldCustomMessage: - m.ClearCustomMessage() + case authidentity.FieldVerifiedAt: + m.ClearVerifiedAt() return nil - case errorpassthroughrule.FieldDescription: - m.ClearDescription() + case authidentity.FieldIssuer: + m.ClearIssuer() return nil } - return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) + return fmt.Errorf("unknown AuthIdentity nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. 
-func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { +func (m *AuthIdentityMutation) ResetField(name string) error { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentity.FieldCreatedAt: m.ResetCreatedAt() return nil - case errorpassthroughrule.FieldUpdatedAt: + case authidentity.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case errorpassthroughrule.FieldName: - m.ResetName() - return nil - case errorpassthroughrule.FieldEnabled: - m.ResetEnabled() - return nil - case errorpassthroughrule.FieldPriority: - m.ResetPriority() - return nil - case errorpassthroughrule.FieldErrorCodes: - m.ResetErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ResetKeywords() - return nil - case errorpassthroughrule.FieldMatchMode: - m.ResetMatchMode() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ResetPlatforms() + case authidentity.FieldUserID: + m.ResetUserID() return nil - case errorpassthroughrule.FieldPassthroughCode: - m.ResetPassthroughCode() + case authidentity.FieldProviderType: + m.ResetProviderType() return nil - case errorpassthroughrule.FieldResponseCode: - m.ResetResponseCode() + case authidentity.FieldProviderKey: + m.ResetProviderKey() return nil - case errorpassthroughrule.FieldPassthroughBody: - m.ResetPassthroughBody() + case authidentity.FieldProviderSubject: + m.ResetProviderSubject() return nil - case errorpassthroughrule.FieldCustomMessage: - m.ResetCustomMessage() + case authidentity.FieldVerifiedAt: + m.ResetVerifiedAt() return nil - case errorpassthroughrule.FieldSkipMonitoring: - m.ResetSkipMonitoring() + case authidentity.FieldIssuer: + m.ResetIssuer() return nil - case errorpassthroughrule.FieldDescription: - m.ResetDescription() + case authidentity.FieldMetadata: + m.ResetMetadata() return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentity field %s", name) } // AddedEdges returns all edge names that were 
set/added in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.user != nil { + edges = append(edges, authidentity.EdgeUser) + } + if m.channels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.adoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { +func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.channels)) + for id := range m.channels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.adoption_decisions)) + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return ids + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedchannels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.removedadoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. 
-func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { +func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.removedchannels)) + for id := range m.removedchannels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.removedadoption_decisions)) + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.cleareduser { + edges = append(edges, authidentity.EdgeUser) + } + if m.clearedchannels { + edges = append(edges, authidentity.EdgeChannels) + } + if m.clearedadoption_decisions { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { +func (m *AuthIdentityMutation) EdgeCleared(name string) bool { + switch name { + case authidentity.EdgeUser: + return m.cleareduser + case authidentity.EdgeChannels: + return m.clearedchannels + case authidentity.EdgeAdoptionDecisions: + return m.clearedadoption_decisions + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. 
-func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +func (m *AuthIdentityMutation) ClearEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown AuthIdentity unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) +func (m *AuthIdentityMutation) ResetEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ResetUser() + return nil + case authidentity.EdgeChannels: + m.ResetChannels() + return nil + case authidentity.EdgeAdoptionDecisions: + m.ResetAdoptionDecisions() + return nil + } + return fmt.Errorf("unknown AuthIdentity edge %s", name) } -// GroupMutation represents an operation that mutates the Group nodes in the graph. -type GroupMutation struct { +// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph. 
+type AuthIdentityChannelMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - fallback_group_id_on_invalid_request *int64 - addfallback_group_id_on_invalid_request *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - mcp_xml_inject *bool - supported_model_scopes *[]string - appendsupported_model_scopes []string - sort_order *int - addsort_order *int - allow_messages_dispatch *bool - require_oauth_only *bool - require_privacy_set *bool - default_mapped_model *string - messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - 
predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + channel *string + channel_app_id *string + channel_subject *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*AuthIdentityChannel, error) + predicates []predicate.AuthIdentityChannel } -var _ ent.Mutation = (*GroupMutation)(nil) +var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil) -// groupOption allows management of the mutation configuration using functional options. -type groupOption func(*GroupMutation) +// authidentitychannelOption allows management of the mutation configuration using functional options. +type authidentitychannelOption func(*AuthIdentityChannelMutation) -// newGroupMutation creates new mutation for the Group entity. -func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { - m := &GroupMutation{ +// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity. +func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation { + m := &AuthIdentityChannelMutation{ config: c, op: op, - typ: TypeGroup, + typ: TypeAuthIdentityChannel, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -8298,20 +7968,20 @@ func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { return m } -// withGroupID sets the ID field of the mutation. -func withGroupID(id int64) groupOption { - return func(m *GroupMutation) { +// withAuthIdentityChannelID sets the ID field of the mutation. 
+func withAuthIdentityChannelID(id int64) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { var ( err error once sync.Once - value *Group + value *AuthIdentityChannel ) - m.oldValue = func(ctx context.Context) (*Group, error) { + m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Group.Get(ctx, id) + value, err = m.Client().AuthIdentityChannel.Get(ctx, id) } }) return value, err @@ -8320,10 +7990,10 @@ func withGroupID(id int64) groupOption { } } -// withGroup sets the old Group of the mutation. -func withGroup(node *Group) groupOption { - return func(m *GroupMutation) { - m.oldValue = func(context.Context) (*Group, error) { +// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation. +func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + m.oldValue = func(context.Context) (*AuthIdentityChannel, error) { return node, nil } m.id = &node.ID @@ -8332,7 +8002,7 @@ func withGroup(node *Group) groupOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m GroupMutation) Client() *Client { +func (m AuthIdentityChannelMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -8340,7 +8010,7 @@ func (m GroupMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. 
-func (m GroupMutation) Tx() (*Tx, error) { +func (m AuthIdentityChannelMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -8351,7 +8021,7 @@ func (m GroupMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *GroupMutation) ID() (id int64, exists bool) { +func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -8362,7 +8032,7 @@ func (m *GroupMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -8371,19 +8041,19 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *GroupMutation) SetCreatedAt(t time.Time) { +func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. 
-func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -8391,10 +8061,10 @@ func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -8409,17 +8079,17 @@ func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err erro } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *GroupMutation) ResetCreatedAt() { +func (m *AuthIdentityChannelMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *GroupMutation) SetUpdatedAt(t time.Time) { +func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -8427,10 +8097,10 @@ func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. 
-// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -8445,1737 +8115,1646 @@ func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err erro } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *GroupMutation) ResetUpdatedAt() { +func (m *AuthIdentityChannelMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetDeletedAt sets the "deleted_at" field. -func (m *GroupMutation) SetDeletedAt(t time.Time) { - m.deleted_at = &t +// SetIdentityID sets the "identity_id" field. +func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) { + m.identity = &i } -// DeletedAt returns the value of the "deleted_at" field in the mutation. -func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { - v := m.deleted_at +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) { + v := m.identity if v == nil { return } return *v, true } -// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity. 
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { +func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDeletedAt requires an ID field in the mutation") + return v, errors.New("OldIdentityID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) } - return oldValue.DeletedAt, nil -} - -// ClearDeletedAt clears the value of the "deleted_at" field. -func (m *GroupMutation) ClearDeletedAt() { - m.deleted_at = nil - m.clearedFields[group.FieldDeletedAt] = struct{}{} -} - -// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. -func (m *GroupMutation) DeletedAtCleared() bool { - _, ok := m.clearedFields[group.FieldDeletedAt] - return ok + return oldValue.IdentityID, nil } -// ResetDeletedAt resets all changes to the "deleted_at" field. -func (m *GroupMutation) ResetDeletedAt() { - m.deleted_at = nil - delete(m.clearedFields, group.FieldDeletedAt) +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *AuthIdentityChannelMutation) ResetIdentityID() { + m.identity = nil } -// SetName sets the "name" field. -func (m *GroupMutation) SetName(s string) { - m.name = &s +// SetProviderType sets the "provider_type" field. 
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) { + m.provider_type = &s } -// Name returns the value of the "name" field in the mutation. -func (m *GroupMutation) Name() (r string, exists bool) { - v := m.name +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) { + v := m.provider_type if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { +func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldProviderType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) } - return oldValue.Name, nil + return oldValue.ProviderType, nil } -// ResetName resets all changes to the "name" field. -func (m *GroupMutation) ResetName() { - m.name = nil +// ResetProviderType resets all changes to the "provider_type" field. 
+func (m *AuthIdentityChannelMutation) ResetProviderType() { + m.provider_type = nil } -// SetDescription sets the "description" field. -func (m *GroupMutation) SetDescription(s string) { - m.description = &s +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityChannelMutation) SetProviderKey(s string) { + m.provider_key = &s } -// Description returns the value of the "description" field in the mutation. -func (m *GroupMutation) Description() (r string, exists bool) { - v := m.description +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key if v == nil { return } return *v, true } -// OldDescription returns the old "description" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { +func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + return v, errors.New("OldProviderKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) } - return oldValue.Description, nil -} - -// ClearDescription clears the value of the "description" field. -func (m *GroupMutation) ClearDescription() { - m.description = nil - m.clearedFields[group.FieldDescription] = struct{}{} -} - -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *GroupMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[group.FieldDescription] - return ok + return oldValue.ProviderKey, nil } -// ResetDescription resets all changes to the "description" field. -func (m *GroupMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, group.FieldDescription) +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityChannelMutation) ResetProviderKey() { + m.provider_key = nil } -// SetRateMultiplier sets the "rate_multiplier" field. -func (m *GroupMutation) SetRateMultiplier(f float64) { - m.rate_multiplier = &f - m.addrate_multiplier = nil +// SetChannel sets the "channel" field. +func (m *AuthIdentityChannelMutation) SetChannel(s string) { + m.channel = &s } -// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. 
-func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { - v := m.rate_multiplier +// Channel returns the value of the "channel" field in the mutation. +func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) { + v := m.channel if v == nil { return } return *v, true } -// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { +func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + return v, errors.New("OldChannel is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + return v, errors.New("OldChannel requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) - } - return oldValue.RateMultiplier, nil -} - -// AddRateMultiplier adds f to the "rate_multiplier" field. -func (m *GroupMutation) AddRateMultiplier(f float64) { - if m.addrate_multiplier != nil { - *m.addrate_multiplier += f - } else { - m.addrate_multiplier = &f - } -} - -// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. 
-func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { - v := m.addrate_multiplier - if v == nil { - return + return v, fmt.Errorf("querying old value for OldChannel: %w", err) } - return *v, true + return oldValue.Channel, nil } -// ResetRateMultiplier resets all changes to the "rate_multiplier" field. -func (m *GroupMutation) ResetRateMultiplier() { - m.rate_multiplier = nil - m.addrate_multiplier = nil +// ResetChannel resets all changes to the "channel" field. +func (m *AuthIdentityChannelMutation) ResetChannel() { + m.channel = nil } -// SetIsExclusive sets the "is_exclusive" field. -func (m *GroupMutation) SetIsExclusive(b bool) { - m.is_exclusive = &b +// SetChannelAppID sets the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) { + m.channel_app_id = &s } -// IsExclusive returns the value of the "is_exclusive" field in the mutation. -func (m *GroupMutation) IsExclusive() (r bool, exists bool) { - v := m.is_exclusive +// ChannelAppID returns the value of the "channel_app_id" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) { + v := m.channel_app_id if v == nil { return } return *v, true } -// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { +func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") + return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIsExclusive requires an ID field in the mutation") + return v, errors.New("OldChannelAppID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) + return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err) } - return oldValue.IsExclusive, nil + return oldValue.ChannelAppID, nil } -// ResetIsExclusive resets all changes to the "is_exclusive" field. -func (m *GroupMutation) ResetIsExclusive() { - m.is_exclusive = nil +// ResetChannelAppID resets all changes to the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) ResetChannelAppID() { + m.channel_app_id = nil } -// SetStatus sets the "status" field. -func (m *GroupMutation) SetStatus(s string) { - m.status = &s +// SetChannelSubject sets the "channel_subject" field. +func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) { + m.channel_subject = &s } -// Status returns the value of the "status" field in the mutation. -func (m *GroupMutation) Status() (r string, exists bool) { - v := m.status +// ChannelSubject returns the value of the "channel_subject" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) { + v := m.channel_subject if v == nil { return } return *v, true } -// OldStatus returns the old "status" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. 
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") + return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + return v, errors.New("OldChannelSubject requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err) } - return oldValue.Status, nil + return oldValue.ChannelSubject, nil } -// ResetStatus resets all changes to the "status" field. -func (m *GroupMutation) ResetStatus() { - m.status = nil +// ResetChannelSubject resets all changes to the "channel_subject" field. +func (m *AuthIdentityChannelMutation) ResetChannelSubject() { + m.channel_subject = nil } -// SetPlatform sets the "platform" field. -func (m *GroupMutation) SetPlatform(s string) { - m.platform = &s +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value } -// Platform returns the value of the "platform" field in the mutation. -func (m *GroupMutation) Platform() (r string, exists bool) { - v := m.platform +// Metadata returns the value of the "metadata" field in the mutation. 
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata if v == nil { return } return *v, true } -// OldPlatform returns the old "platform" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) { +func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatform is only allowed on UpdateOne operations") + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatform requires an ID field in the mutation") + return v, errors.New("OldMetadata requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) } - return oldValue.Platform, nil + return oldValue.Metadata, nil } -// ResetPlatform resets all changes to the "platform" field. -func (m *GroupMutation) ResetPlatform() { - m.platform = nil +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityChannelMutation) ResetMetadata() { + m.metadata = nil } -// SetSubscriptionType sets the "subscription_type" field. -func (m *GroupMutation) SetSubscriptionType(s string) { - m.subscription_type = &s +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. 
+func (m *AuthIdentityChannelMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{} } -// SubscriptionType returns the value of the "subscription_type" field in the mutation. -func (m *GroupMutation) SubscriptionType() (r string, exists bool) { - v := m.subscription_type - if v == nil { - return - } - return *v, true +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *AuthIdentityChannelMutation) IdentityCleared() bool { + return m.clearedidentity } -// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err) +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) } - return oldValue.SubscriptionType, nil + return } -// ResetSubscriptionType resets all changes to the "subscription_type" field. -func (m *GroupMutation) ResetSubscriptionType() { - m.subscription_type = nil +// ResetIdentity resets all changes to the "identity" edge. 
+func (m *AuthIdentityChannelMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false } -// SetDailyLimitUsd sets the "daily_limit_usd" field. -func (m *GroupMutation) SetDailyLimitUsd(f float64) { - m.daily_limit_usd = &f - m.adddaily_limit_usd = nil +// Where appends a list predicates to the AuthIdentityChannelMutation builder. +func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) { + m.predicates = append(m.predicates, ps...) } -// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. -func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) { - v := m.daily_limit_usd - if v == nil { - return +// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentityChannel, len(ps)) + for i := range ps { + p[i] = ps[i] } - return *v, true + m.Where(p...) } -// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") +// Op returns the operation name. +func (m *AuthIdentityChannelMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityChannelMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentityChannel). 
+func (m *AuthIdentityChannelMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityChannelMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentitychannel.FieldCreatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + if m.updated_at != nil { + fields = append(fields, authidentitychannel.FieldUpdatedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + if m.identity != nil { + fields = append(fields, authidentitychannel.FieldIdentityID) } - return oldValue.DailyLimitUsd, nil + if m.provider_type != nil { + fields = append(fields, authidentitychannel.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentitychannel.FieldProviderKey) + } + if m.channel != nil { + fields = append(fields, authidentitychannel.FieldChannel) + } + if m.channel_app_id != nil { + fields = append(fields, authidentitychannel.FieldChannelAppID) + } + if m.channel_subject != nil { + fields = append(fields, authidentitychannel.FieldChannelSubject) + } + if m.metadata != nil { + fields = append(fields, authidentitychannel.FieldMetadata) + } + return fields } -// AddDailyLimitUsd adds f to the "daily_limit_usd" field. -func (m *GroupMutation) AddDailyLimitUsd(f float64) { - if m.adddaily_limit_usd != nil { - *m.adddaily_limit_usd += f - } else { - m.adddaily_limit_usd = &f +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.CreatedAt() + case authidentitychannel.FieldUpdatedAt: + return m.UpdatedAt() + case authidentitychannel.FieldIdentityID: + return m.IdentityID() + case authidentitychannel.FieldProviderType: + return m.ProviderType() + case authidentitychannel.FieldProviderKey: + return m.ProviderKey() + case authidentitychannel.FieldChannel: + return m.Channel() + case authidentitychannel.FieldChannelAppID: + return m.ChannelAppID() + case authidentitychannel.FieldChannelSubject: + return m.ChannelSubject() + case authidentitychannel.FieldMetadata: + return m.Metadata() } + return nil, false } -// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. -func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) { - v := m.adddaily_limit_usd - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentitychannel.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentitychannel.FieldIdentityID: + return m.OldIdentityID(ctx) + case authidentitychannel.FieldProviderType: + return m.OldProviderType(ctx) + case authidentitychannel.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentitychannel.FieldChannel: + return m.OldChannel(ctx) + case authidentitychannel.FieldChannelAppID: + return m.OldChannelAppID(ctx) + case authidentitychannel.FieldChannelSubject: + return m.OldChannelSubject(ctx) + case authidentitychannel.FieldMetadata: + return m.OldMetadata(ctx) } - return *v, true + return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name) } -// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. -func (m *GroupMutation) ClearDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - m.clearedFields[group.FieldDailyLimitUsd] = struct{}{} +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentitychannel.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentitychannel.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentitychannel.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case authidentitychannel.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentitychannel.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentitychannel.FieldChannel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannel(v) + return nil + case authidentitychannel.FieldChannelAppID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelAppID(v) + return nil + case authidentitychannel.FieldChannelSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelSubject(v) + return nil + case authidentitychannel.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } -// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. 
-func (m *GroupMutation) DailyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldDailyLimitUsd] - return ok +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityChannelMutation) AddedFields() []string { + var fields []string + return fields } -// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. -func (m *GroupMutation) ResetDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - delete(m.clearedFields, group.FieldDailyLimitUsd) +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false } -// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. -func (m *GroupMutation) SetWeeklyLimitUsd(f float64) { - m.weekly_limit_usd = &f - m.addweekly_limit_usd = nil +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name) } -// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. -func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) { - v := m.weekly_limit_usd - if v == nil { - return - } - return *v, true +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityChannelMutation) ClearedFields() []string { + return nil } -// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity. 
-// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) - } - return oldValue.WeeklyLimitUsd, nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. -func (m *GroupMutation) AddWeeklyLimitUsd(f float64) { - if m.addweekly_limit_usd != nil { - *m.addweekly_limit_usd += f - } else { - m.addweekly_limit_usd = &f +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *AuthIdentityChannelMutation) ResetField(name string) error { + switch name { + case authidentitychannel.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentitychannel.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentitychannel.FieldIdentityID: + m.ResetIdentityID() + return nil + case authidentitychannel.FieldProviderType: + m.ResetProviderType() + return nil + case authidentitychannel.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentitychannel.FieldChannel: + m.ResetChannel() + return nil + case authidentitychannel.FieldChannelAppID: + m.ResetChannelAppID() + return nil + case authidentitychannel.FieldChannelSubject: + m.ResetChannelSubject() + return nil + case authidentitychannel.FieldMetadata: + m.ResetMetadata() + return nil } + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } -// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { - v := m.addweekly_limit_usd - if v == nil { - return +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityChannelMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.identity != nil { + edges = append(edges, authidentitychannel.EdgeIdentity) } - return *v, true + return edges } -// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. -func (m *GroupMutation) ClearWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{} +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentitychannel.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil } -// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) WeeklyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldWeeklyLimitUsd] - return ok +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityChannelMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges } -// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. -func (m *GroupMutation) ResetWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - delete(m.clearedFields, group.FieldWeeklyLimitUsd) +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value { + return nil } -// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. -func (m *GroupMutation) SetMonthlyLimitUsd(f float64) { - m.monthly_limit_usd = &f - m.addmonthly_limit_usd = nil +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityChannelMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedidentity { + edges = append(edges, authidentitychannel.EdgeIdentity) + } + return edges } -// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. -func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) { - v := m.monthly_limit_usd - if v == nil { - return +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. 
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool { + switch name { + case authidentitychannel.EdgeIdentity: + return m.clearedidentity } - return *v, true + return false } -// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ClearEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ClearIdentity() + return nil } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") + return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityChannelMutation) ResetEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ResetIdentity() + return nil } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) + return fmt.Errorf("unknown AuthIdentityChannel edge %s", name) +} + +// ChannelMonitorMutation represents an operation that mutates the ChannelMonitor nodes in the graph. 
+type ChannelMonitorMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + provider *channelmonitor.Provider + endpoint *string + api_key_encrypted *string + primary_model *string + extra_models *[]string + appendextra_models []string + group_name *string + enabled *bool + interval_seconds *int + addinterval_seconds *int + last_checked_at *time.Time + created_by *int64 + addcreated_by *int64 + extra_headers *map[string]string + body_override_mode *string + body_override *map[string]interface{} + clearedFields map[string]struct{} + history map[int64]struct{} + removedhistory map[int64]struct{} + clearedhistory bool + daily_rollups map[int64]struct{} + removeddaily_rollups map[int64]struct{} + cleareddaily_rollups bool + request_template *int64 + clearedrequest_template bool + done bool + oldValue func(context.Context) (*ChannelMonitor, error) + predicates []predicate.ChannelMonitor +} + +var _ ent.Mutation = (*ChannelMonitorMutation)(nil) + +// channelmonitorOption allows management of the mutation configuration using functional options. +type channelmonitorOption func(*ChannelMonitorMutation) + +// newChannelMonitorMutation creates new mutation for the ChannelMonitor entity. +func newChannelMonitorMutation(c config, op Op, opts ...channelmonitorOption) *ChannelMonitorMutation { + m := &ChannelMonitorMutation{ + config: c, + op: op, + typ: TypeChannelMonitor, + clearedFields: make(map[string]struct{}), } - return oldValue.MonthlyLimitUsd, nil + for _, opt := range opts { + opt(m) + } + return m } -// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. -func (m *GroupMutation) AddMonthlyLimitUsd(f float64) { - if m.addmonthly_limit_usd != nil { - *m.addmonthly_limit_usd += f - } else { - m.addmonthly_limit_usd = &f +// withChannelMonitorID sets the ID field of the mutation. 
+func withChannelMonitorID(id int64) channelmonitorOption { + return func(m *ChannelMonitorMutation) { + var ( + err error + once sync.Once + value *ChannelMonitor + ) + m.oldValue = func(ctx context.Context) (*ChannelMonitor, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().ChannelMonitor.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } } -// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { - v := m.addmonthly_limit_usd - if v == nil { - return +// withChannelMonitor sets the old ChannelMonitor of the mutation. +func withChannelMonitor(node *ChannelMonitor) channelmonitorOption { + return func(m *ChannelMonitorMutation) { + m.oldValue = func(context.Context) (*ChannelMonitor, error) { + return node, nil + } + m.id = &node.ID } - return *v, true } -// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. -func (m *GroupMutation) ClearMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{} +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ChannelMonitorMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) MonthlyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldMonthlyLimitUsd] - return ok +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. 
+func (m ChannelMonitorMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. -func (m *GroupMutation) ResetMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - delete(m.clearedFields, group.FieldMonthlyLimitUsd) +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ChannelMonitorMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// SetDefaultValidityDays sets the "default_validity_days" field. -func (m *GroupMutation) SetDefaultValidityDays(i int) { - m.default_validity_days = &i - m.adddefault_validity_days = nil +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *ChannelMonitorMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().ChannelMonitor.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } } -// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. -func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { - v := m.default_validity_days +// SetCreatedAt sets the "created_at" field. 
+func (m *ChannelMonitorMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ChannelMonitorMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { +func (m *ChannelMonitorMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) - } - return oldValue.DefaultValidityDays, nil -} - -// AddDefaultValidityDays adds i to the "default_validity_days" field. -func (m *GroupMutation) AddDefaultValidityDays(i int) { - if m.adddefault_validity_days != nil { - *m.adddefault_validity_days += i - } else { - m.adddefault_validity_days = &i - } -} - -// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. 
-func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { - v := m.adddefault_validity_days - if v == nil { - return + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return *v, true + return oldValue.CreatedAt, nil } -// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. -func (m *GroupMutation) ResetDefaultValidityDays() { - m.default_validity_days = nil - m.adddefault_validity_days = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ChannelMonitorMutation) ResetCreatedAt() { + m.created_at = nil } -// SetImagePrice1k sets the "image_price_1k" field. -func (m *GroupMutation) SetImagePrice1k(f float64) { - m.image_price_1k = &f - m.addimage_price_1k = nil +// SetUpdatedAt sets the "updated_at" field. +func (m *ChannelMonitorMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// ImagePrice1k returns the value of the "image_price_1k" field in the mutation. -func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) { - v := m.image_price_1k +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ChannelMonitorMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) { +func (m *ChannelMonitorMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice1k requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.ImagePrice1k, nil + return oldValue.UpdatedAt, nil } -// AddImagePrice1k adds f to the "image_price_1k" field. -func (m *GroupMutation) AddImagePrice1k(f float64) { - if m.addimage_price_1k != nil { - *m.addimage_price_1k += f - } else { - m.addimage_price_1k = &f - } +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ChannelMonitorMutation) ResetUpdatedAt() { + m.updated_at = nil } -// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation. -func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) { - v := m.addimage_price_1k +// SetName sets the "name" field. +func (m *ChannelMonitorMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *ChannelMonitorMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// ClearImagePrice1k clears the value of the "image_price_1k" field. 
-func (m *GroupMutation) ClearImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - m.clearedFields[group.FieldImagePrice1k] = struct{}{} -} - -// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice1kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice1k] - return ok +// OldName returns the old "name" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil } -// ResetImagePrice1k resets all changes to the "image_price_1k" field. -func (m *GroupMutation) ResetImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - delete(m.clearedFields, group.FieldImagePrice1k) +// ResetName resets all changes to the "name" field. +func (m *ChannelMonitorMutation) ResetName() { + m.name = nil } -// SetImagePrice2k sets the "image_price_2k" field. -func (m *GroupMutation) SetImagePrice2k(f float64) { - m.image_price_2k = &f - m.addimage_price_2k = nil +// SetProvider sets the "provider" field. +func (m *ChannelMonitorMutation) SetProvider(c channelmonitor.Provider) { + m.provider = &c } -// ImagePrice2k returns the value of the "image_price_2k" field in the mutation. -func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) { - v := m.image_price_2k +// Provider returns the value of the "provider" field in the mutation. 
+func (m *ChannelMonitorMutation) Provider() (r channelmonitor.Provider, exists bool) { + v := m.provider if v == nil { return } return *v, true } -// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldProvider returns the old "provider" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) { +func (m *ChannelMonitorMutation) OldProvider(ctx context.Context) (v channelmonitor.Provider, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations") + return v, errors.New("OldProvider is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice2k requires an ID field in the mutation") + return v, errors.New("OldProvider requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err) + return v, fmt.Errorf("querying old value for OldProvider: %w", err) } - return oldValue.ImagePrice2k, nil + return oldValue.Provider, nil } -// AddImagePrice2k adds f to the "image_price_2k" field. -func (m *GroupMutation) AddImagePrice2k(f float64) { - if m.addimage_price_2k != nil { - *m.addimage_price_2k += f - } else { - m.addimage_price_2k = &f - } +// ResetProvider resets all changes to the "provider" field. +func (m *ChannelMonitorMutation) ResetProvider() { + m.provider = nil } -// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation. 
-func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) { - v := m.addimage_price_2k +// SetEndpoint sets the "endpoint" field. +func (m *ChannelMonitorMutation) SetEndpoint(s string) { + m.endpoint = &s +} + +// Endpoint returns the value of the "endpoint" field in the mutation. +func (m *ChannelMonitorMutation) Endpoint() (r string, exists bool) { + v := m.endpoint if v == nil { return } return *v, true } -// ClearImagePrice2k clears the value of the "image_price_2k" field. -func (m *GroupMutation) ClearImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - m.clearedFields[group.FieldImagePrice2k] = struct{}{} -} - -// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice2kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice2k] - return ok +// OldEndpoint returns the old "endpoint" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorMutation) OldEndpoint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEndpoint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEndpoint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEndpoint: %w", err) + } + return oldValue.Endpoint, nil } -// ResetImagePrice2k resets all changes to the "image_price_2k" field. -func (m *GroupMutation) ResetImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - delete(m.clearedFields, group.FieldImagePrice2k) +// ResetEndpoint resets all changes to the "endpoint" field. 
+func (m *ChannelMonitorMutation) ResetEndpoint() { + m.endpoint = nil } -// SetImagePrice4k sets the "image_price_4k" field. -func (m *GroupMutation) SetImagePrice4k(f float64) { - m.image_price_4k = &f - m.addimage_price_4k = nil +// SetAPIKeyEncrypted sets the "api_key_encrypted" field. +func (m *ChannelMonitorMutation) SetAPIKeyEncrypted(s string) { + m.api_key_encrypted = &s } -// ImagePrice4k returns the value of the "image_price_4k" field in the mutation. -func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) { - v := m.image_price_4k +// APIKeyEncrypted returns the value of the "api_key_encrypted" field in the mutation. +func (m *ChannelMonitorMutation) APIKeyEncrypted() (r string, exists bool) { + v := m.api_key_encrypted if v == nil { return } return *v, true } -// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldAPIKeyEncrypted returns the old "api_key_encrypted" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) { +func (m *ChannelMonitorMutation) OldAPIKeyEncrypted(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations") + return v, errors.New("OldAPIKeyEncrypted is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice4k requires an ID field in the mutation") + return v, errors.New("OldAPIKeyEncrypted requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err) + return v, fmt.Errorf("querying old value for OldAPIKeyEncrypted: %w", err) } - return oldValue.ImagePrice4k, nil + return oldValue.APIKeyEncrypted, nil } -// AddImagePrice4k adds f to the "image_price_4k" field. -func (m *GroupMutation) AddImagePrice4k(f float64) { - if m.addimage_price_4k != nil { - *m.addimage_price_4k += f - } else { - m.addimage_price_4k = &f - } +// ResetAPIKeyEncrypted resets all changes to the "api_key_encrypted" field. +func (m *ChannelMonitorMutation) ResetAPIKeyEncrypted() { + m.api_key_encrypted = nil } -// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation. -func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) { - v := m.addimage_price_4k +// SetPrimaryModel sets the "primary_model" field. +func (m *ChannelMonitorMutation) SetPrimaryModel(s string) { + m.primary_model = &s +} + +// PrimaryModel returns the value of the "primary_model" field in the mutation. +func (m *ChannelMonitorMutation) PrimaryModel() (r string, exists bool) { + v := m.primary_model if v == nil { return } return *v, true } -// ClearImagePrice4k clears the value of the "image_price_4k" field. 
-func (m *GroupMutation) ClearImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - m.clearedFields[group.FieldImagePrice4k] = struct{}{} -} - -// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice4kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice4k] - return ok +// OldPrimaryModel returns the old "primary_model" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorMutation) OldPrimaryModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrimaryModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrimaryModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrimaryModel: %w", err) + } + return oldValue.PrimaryModel, nil } -// ResetImagePrice4k resets all changes to the "image_price_4k" field. -func (m *GroupMutation) ResetImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - delete(m.clearedFields, group.FieldImagePrice4k) +// ResetPrimaryModel resets all changes to the "primary_model" field. +func (m *ChannelMonitorMutation) ResetPrimaryModel() { + m.primary_model = nil } -// SetClaudeCodeOnly sets the "claude_code_only" field. -func (m *GroupMutation) SetClaudeCodeOnly(b bool) { - m.claude_code_only = &b +// SetExtraModels sets the "extra_models" field. +func (m *ChannelMonitorMutation) SetExtraModels(s []string) { + m.extra_models = &s + m.appendextra_models = nil } -// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. 
-func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { - v := m.claude_code_only +// ExtraModels returns the value of the "extra_models" field in the mutation. +func (m *ChannelMonitorMutation) ExtraModels() (r []string, exists bool) { + v := m.extra_models if v == nil { return } return *v, true } -// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldExtraModels returns the old "extra_models" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { +func (m *ChannelMonitorMutation) OldExtraModels(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") + return v, errors.New("OldExtraModels is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") + return v, errors.New("OldExtraModels requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) + return v, fmt.Errorf("querying old value for OldExtraModels: %w", err) } - return oldValue.ClaudeCodeOnly, nil + return oldValue.ExtraModels, nil } -// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. -func (m *GroupMutation) ResetClaudeCodeOnly() { - m.claude_code_only = nil +// AppendExtraModels adds s to the "extra_models" field. +func (m *ChannelMonitorMutation) AppendExtraModels(s []string) { + m.appendextra_models = append(m.appendextra_models, s...) 
} -// SetFallbackGroupID sets the "fallback_group_id" field. -func (m *GroupMutation) SetFallbackGroupID(i int64) { - m.fallback_group_id = &i - m.addfallback_group_id = nil +// AppendedExtraModels returns the list of values that were appended to the "extra_models" field in this mutation. +func (m *ChannelMonitorMutation) AppendedExtraModels() ([]string, bool) { + if len(m.appendextra_models) == 0 { + return nil, false + } + return m.appendextra_models, true } -// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. -func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { - v := m.fallback_group_id +// ResetExtraModels resets all changes to the "extra_models" field. +func (m *ChannelMonitorMutation) ResetExtraModels() { + m.extra_models = nil + m.appendextra_models = nil +} + +// SetGroupName sets the "group_name" field. +func (m *ChannelMonitorMutation) SetGroupName(s string) { + m.group_name = &s +} + +// GroupName returns the value of the "group_name" field in the mutation. +func (m *ChannelMonitorMutation) GroupName() (r string, exists bool) { + v := m.group_name if v == nil { return } return *v, true } -// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldGroupName returns the old "group_name" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { +func (m *ChannelMonitorMutation) OldGroupName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") + return v, errors.New("OldGroupName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") + return v, errors.New("OldGroupName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) + return v, fmt.Errorf("querying old value for OldGroupName: %w", err) } - return oldValue.FallbackGroupID, nil + return oldValue.GroupName, nil } -// AddFallbackGroupID adds i to the "fallback_group_id" field. -func (m *GroupMutation) AddFallbackGroupID(i int64) { - if m.addfallback_group_id != nil { - *m.addfallback_group_id += i - } else { - m.addfallback_group_id = &i - } +// ClearGroupName clears the value of the "group_name" field. +func (m *ChannelMonitorMutation) ClearGroupName() { + m.group_name = nil + m.clearedFields[channelmonitor.FieldGroupName] = struct{}{} } -// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { - v := m.addfallback_group_id +// GroupNameCleared returns if the "group_name" field was cleared in this mutation. +func (m *ChannelMonitorMutation) GroupNameCleared() bool { + _, ok := m.clearedFields[channelmonitor.FieldGroupName] + return ok +} + +// ResetGroupName resets all changes to the "group_name" field. +func (m *ChannelMonitorMutation) ResetGroupName() { + m.group_name = nil + delete(m.clearedFields, channelmonitor.FieldGroupName) +} + +// SetEnabled sets the "enabled" field. 
+func (m *ChannelMonitorMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ChannelMonitorMutation) Enabled() (r bool, exists bool) { + v := m.enabled if v == nil { return } return *v, true } -// ClearFallbackGroupID clears the value of the "fallback_group_id" field. -func (m *GroupMutation) ClearFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - m.clearedFields[group.FieldFallbackGroupID] = struct{}{} -} - -// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupID] - return ok +// OldEnabled returns the old "enabled" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil } -// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. -func (m *GroupMutation) ResetFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - delete(m.clearedFields, group.FieldFallbackGroupID) +// ResetEnabled resets all changes to the "enabled" field. +func (m *ChannelMonitorMutation) ResetEnabled() { + m.enabled = nil } -// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. 
-func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { - m.fallback_group_id_on_invalid_request = &i - m.addfallback_group_id_on_invalid_request = nil +// SetIntervalSeconds sets the "interval_seconds" field. +func (m *ChannelMonitorMutation) SetIntervalSeconds(i int) { + m.interval_seconds = &i + m.addinterval_seconds = nil } -// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.fallback_group_id_on_invalid_request +// IntervalSeconds returns the value of the "interval_seconds" field in the mutation. +func (m *ChannelMonitorMutation) IntervalSeconds() (r int, exists bool) { + v := m.interval_seconds if v == nil { return } return *v, true } -// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldIntervalSeconds returns the old "interval_seconds" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { +func (m *ChannelMonitorMutation) OldIntervalSeconds(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") + return v, errors.New("OldIntervalSeconds is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + return v, errors.New("OldIntervalSeconds requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + return v, fmt.Errorf("querying old value for OldIntervalSeconds: %w", err) } - return oldValue.FallbackGroupIDOnInvalidRequest, nil + return oldValue.IntervalSeconds, nil } -// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { - if m.addfallback_group_id_on_invalid_request != nil { - *m.addfallback_group_id_on_invalid_request += i +// AddIntervalSeconds adds i to the "interval_seconds" field. +func (m *ChannelMonitorMutation) AddIntervalSeconds(i int) { + if m.addinterval_seconds != nil { + *m.addinterval_seconds += i } else { - m.addfallback_group_id_on_invalid_request = &i + m.addinterval_seconds = &i } } -// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.addfallback_group_id_on_invalid_request +// AddedIntervalSeconds returns the value that was added to the "interval_seconds" field in this mutation. 
+func (m *ChannelMonitorMutation) AddedIntervalSeconds() (r int, exists bool) { + v := m.addinterval_seconds if v == nil { return } return *v, true } -// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} -} - -// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] - return ok -} - -// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +// ResetIntervalSeconds resets all changes to the "interval_seconds" field. +func (m *ChannelMonitorMutation) ResetIntervalSeconds() { + m.interval_seconds = nil + m.addinterval_seconds = nil } -// SetModelRouting sets the "model_routing" field. -func (m *GroupMutation) SetModelRouting(value map[string][]int64) { - m.model_routing = &value +// SetLastCheckedAt sets the "last_checked_at" field. +func (m *ChannelMonitorMutation) SetLastCheckedAt(t time.Time) { + m.last_checked_at = &t } -// ModelRouting returns the value of the "model_routing" field in the mutation. -func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { - v := m.model_routing +// LastCheckedAt returns the value of the "last_checked_at" field in the mutation. 
+func (m *ChannelMonitorMutation) LastCheckedAt() (r time.Time, exists bool) { + v := m.last_checked_at if v == nil { return } return *v, true } -// OldModelRouting returns the old "model_routing" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldLastCheckedAt returns the old "last_checked_at" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { +func (m *ChannelMonitorMutation) OldLastCheckedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") + return v, errors.New("OldLastCheckedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRouting requires an ID field in the mutation") + return v, errors.New("OldLastCheckedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) + return v, fmt.Errorf("querying old value for OldLastCheckedAt: %w", err) } - return oldValue.ModelRouting, nil + return oldValue.LastCheckedAt, nil } -// ClearModelRouting clears the value of the "model_routing" field. -func (m *GroupMutation) ClearModelRouting() { - m.model_routing = nil - m.clearedFields[group.FieldModelRouting] = struct{}{} +// ClearLastCheckedAt clears the value of the "last_checked_at" field. 
+func (m *ChannelMonitorMutation) ClearLastCheckedAt() { + m.last_checked_at = nil + m.clearedFields[channelmonitor.FieldLastCheckedAt] = struct{}{} } -// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. -func (m *GroupMutation) ModelRoutingCleared() bool { - _, ok := m.clearedFields[group.FieldModelRouting] +// LastCheckedAtCleared returns if the "last_checked_at" field was cleared in this mutation. +func (m *ChannelMonitorMutation) LastCheckedAtCleared() bool { + _, ok := m.clearedFields[channelmonitor.FieldLastCheckedAt] return ok } -// ResetModelRouting resets all changes to the "model_routing" field. -func (m *GroupMutation) ResetModelRouting() { - m.model_routing = nil - delete(m.clearedFields, group.FieldModelRouting) +// ResetLastCheckedAt resets all changes to the "last_checked_at" field. +func (m *ChannelMonitorMutation) ResetLastCheckedAt() { + m.last_checked_at = nil + delete(m.clearedFields, channelmonitor.FieldLastCheckedAt) } -// SetModelRoutingEnabled sets the "model_routing_enabled" field. -func (m *GroupMutation) SetModelRoutingEnabled(b bool) { - m.model_routing_enabled = &b +// SetCreatedBy sets the "created_by" field. +func (m *ChannelMonitorMutation) SetCreatedBy(i int64) { + m.created_by = &i + m.addcreated_by = nil } -// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. -func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { - v := m.model_routing_enabled +// CreatedBy returns the value of the "created_by" field in the mutation. +func (m *ChannelMonitorMutation) CreatedBy() (r int64, exists bool) { + v := m.created_by if v == nil { return } return *v, true } -// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedBy returns the old "created_by" field's value of the ChannelMonitor entity. 
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { +func (m *ChannelMonitorMutation) OldCreatedBy(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") + return v, errors.New("OldCreatedBy requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err) } - return oldValue.ModelRoutingEnabled, nil + return oldValue.CreatedBy, nil } -// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. -func (m *GroupMutation) ResetModelRoutingEnabled() { - m.model_routing_enabled = nil +// AddCreatedBy adds i to the "created_by" field. +func (m *ChannelMonitorMutation) AddCreatedBy(i int64) { + if m.addcreated_by != nil { + *m.addcreated_by += i + } else { + m.addcreated_by = &i + } } -// SetMcpXMLInject sets the "mcp_xml_inject" field. -func (m *GroupMutation) SetMcpXMLInject(b bool) { - m.mcp_xml_inject = &b -} - -// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. -func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { - v := m.mcp_xml_inject - if v == nil { - return - } - return *v, true -} - -// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. 
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) - } - return oldValue.McpXMLInject, nil -} - -// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. -func (m *GroupMutation) ResetMcpXMLInject() { - m.mcp_xml_inject = nil -} - -// SetSupportedModelScopes sets the "supported_model_scopes" field. -func (m *GroupMutation) SetSupportedModelScopes(s []string) { - m.supported_model_scopes = &s - m.appendsupported_model_scopes = nil -} - -// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. -func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { - v := m.supported_model_scopes - if v == nil { - return - } - return *v, true -} - -// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) - } - return oldValue.SupportedModelScopes, nil -} - -// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. -func (m *GroupMutation) AppendSupportedModelScopes(s []string) { - m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) -} - -// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. -func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { - if len(m.appendsupported_model_scopes) == 0 { - return nil, false - } - return m.appendsupported_model_scopes, true -} - -// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. -func (m *GroupMutation) ResetSupportedModelScopes() { - m.supported_model_scopes = nil - m.appendsupported_model_scopes = nil -} - -// SetSortOrder sets the "sort_order" field. -func (m *GroupMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil -} - -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *GroupMutation) SortOrder() (r int, exists bool) { - v := m.sort_order - if v == nil { - return - } - return *v, true -} - -// OldSortOrder returns the old "sort_order" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) - } - return oldValue.SortOrder, nil -} - -// AddSortOrder adds i to the "sort_order" field. -func (m *GroupMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i - } -} - -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. -func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order +// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation. +func (m *ChannelMonitorMutation) AddedCreatedBy() (r int64, exists bool) { + v := m.addcreated_by if v == nil { return } return *v, true } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *GroupMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// ResetCreatedBy resets all changes to the "created_by" field. +func (m *ChannelMonitorMutation) ResetCreatedBy() { + m.created_by = nil + m.addcreated_by = nil } -// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. -func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { - m.allow_messages_dispatch = &b +// SetTemplateID sets the "template_id" field. +func (m *ChannelMonitorMutation) SetTemplateID(i int64) { + m.request_template = &i } -// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. -func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { - v := m.allow_messages_dispatch +// TemplateID returns the value of the "template_id" field in the mutation. 
+func (m *ChannelMonitorMutation) TemplateID() (r int64, exists bool) { + v := m.request_template if v == nil { return } return *v, true } -// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldTemplateID returns the old "template_id" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { +func (m *ChannelMonitorMutation) OldTemplateID(ctx context.Context) (v *int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") + return v, errors.New("OldTemplateID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") + return v, errors.New("OldTemplateID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) + return v, fmt.Errorf("querying old value for OldTemplateID: %w", err) } - return oldValue.AllowMessagesDispatch, nil -} - -// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. -func (m *GroupMutation) ResetAllowMessagesDispatch() { - m.allow_messages_dispatch = nil -} - -// SetRequireOauthOnly sets the "require_oauth_only" field. -func (m *GroupMutation) SetRequireOauthOnly(b bool) { - m.require_oauth_only = &b + return oldValue.TemplateID, nil } -// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. 
-func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { - v := m.require_oauth_only - if v == nil { - return - } - return *v, true +// ClearTemplateID clears the value of the "template_id" field. +func (m *ChannelMonitorMutation) ClearTemplateID() { + m.request_template = nil + m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{} } -// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) - } - return oldValue.RequireOauthOnly, nil +// TemplateIDCleared returns if the "template_id" field was cleared in this mutation. +func (m *ChannelMonitorMutation) TemplateIDCleared() bool { + _, ok := m.clearedFields[channelmonitor.FieldTemplateID] + return ok } -// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. -func (m *GroupMutation) ResetRequireOauthOnly() { - m.require_oauth_only = nil +// ResetTemplateID resets all changes to the "template_id" field. +func (m *ChannelMonitorMutation) ResetTemplateID() { + m.request_template = nil + delete(m.clearedFields, channelmonitor.FieldTemplateID) } -// SetRequirePrivacySet sets the "require_privacy_set" field. -func (m *GroupMutation) SetRequirePrivacySet(b bool) { - m.require_privacy_set = &b +// SetExtraHeaders sets the "extra_headers" field. 
+func (m *ChannelMonitorMutation) SetExtraHeaders(value map[string]string) { + m.extra_headers = &value } -// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. -func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { - v := m.require_privacy_set +// ExtraHeaders returns the value of the "extra_headers" field in the mutation. +func (m *ChannelMonitorMutation) ExtraHeaders() (r map[string]string, exists bool) { + v := m.extra_headers if v == nil { return } return *v, true } -// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { +func (m *ChannelMonitorMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + return v, errors.New("OldExtraHeaders requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err) } - return oldValue.RequirePrivacySet, nil + return oldValue.ExtraHeaders, nil } -// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. 
-func (m *GroupMutation) ResetRequirePrivacySet() { - m.require_privacy_set = nil +// ResetExtraHeaders resets all changes to the "extra_headers" field. +func (m *ChannelMonitorMutation) ResetExtraHeaders() { + m.extra_headers = nil } -// SetDefaultMappedModel sets the "default_mapped_model" field. -func (m *GroupMutation) SetDefaultMappedModel(s string) { - m.default_mapped_model = &s +// SetBodyOverrideMode sets the "body_override_mode" field. +func (m *ChannelMonitorMutation) SetBodyOverrideMode(s string) { + m.body_override_mode = &s } -// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. -func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { - v := m.default_mapped_model +// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation. +func (m *ChannelMonitorMutation) BodyOverrideMode() (r string, exists bool) { + v := m.body_override_mode if v == nil { return } return *v, true } -// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") + return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") + return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) + return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err) } - return oldValue.DefaultMappedModel, nil + return oldValue.BodyOverrideMode, nil } -// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. -func (m *GroupMutation) ResetDefaultMappedModel() { - m.default_mapped_model = nil +// ResetBodyOverrideMode resets all changes to the "body_override_mode" field. +func (m *ChannelMonitorMutation) ResetBodyOverrideMode() { + m.body_override_mode = nil } -// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. -func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { - m.messages_dispatch_model_config = &damdmc +// SetBodyOverride sets the "body_override" field. +func (m *ChannelMonitorMutation) SetBodyOverride(value map[string]interface{}) { + m.body_override = &value } -// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. -func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { - v := m.messages_dispatch_model_config +// BodyOverride returns the value of the "body_override" field in the mutation. 
+func (m *ChannelMonitorMutation) BodyOverride() (r map[string]interface{}, exists bool) { + v := m.body_override if v == nil { return } return *v, true } -// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitor entity. +// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { +func (m *ChannelMonitorMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + return v, errors.New("OldBodyOverride requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) - } - return oldValue.MessagesDispatchModelConfig, nil -} - -// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. -func (m *GroupMutation) ResetMessagesDispatchModelConfig() { - m.messages_dispatch_model_config = nil -} - -// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. 
-func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { - if m.api_keys == nil { - m.api_keys = make(map[int64]struct{}) - } - for i := range ids { - m.api_keys[ids[i]] = struct{}{} - } -} - -// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) ClearAPIKeys() { - m.clearedapi_keys = true -} - -// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. -func (m *GroupMutation) APIKeysCleared() bool { - return m.clearedapi_keys -} - -// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. -func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { - if m.removedapi_keys == nil { - m.removedapi_keys = make(map[int64]struct{}) - } - for i := range ids { - delete(m.api_keys, ids[i]) - m.removedapi_keys[ids[i]] = struct{}{} - } -} - -// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { - for id := range m.removedapi_keys { - ids = append(ids, id) - } - return -} - -// APIKeysIDs returns the "api_keys" edge IDs in the mutation. -func (m *GroupMutation) APIKeysIDs() (ids []int64) { - for id := range m.api_keys { - ids = append(ids, id) - } - return -} - -// ResetAPIKeys resets all changes to the "api_keys" edge. -func (m *GroupMutation) ResetAPIKeys() { - m.api_keys = nil - m.clearedapi_keys = false - m.removedapi_keys = nil -} - -// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. -func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { - if m.redeem_codes == nil { - m.redeem_codes = make(map[int64]struct{}) - } - for i := range ids { - m.redeem_codes[ids[i]] = struct{}{} - } -} - -// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) ClearRedeemCodes() { - m.clearedredeem_codes = true -} - -// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. 
-func (m *GroupMutation) RedeemCodesCleared() bool { - return m.clearedredeem_codes -} - -// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. -func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { - if m.removedredeem_codes == nil { - m.removedredeem_codes = make(map[int64]struct{}) - } - for i := range ids { - delete(m.redeem_codes, ids[i]) - m.removedredeem_codes[ids[i]] = struct{}{} + return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err) } + return oldValue.BodyOverride, nil } -// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { - for id := range m.removedredeem_codes { - ids = append(ids, id) - } - return -} - -// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. -func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { - for id := range m.redeem_codes { - ids = append(ids, id) - } - return -} - -// ResetRedeemCodes resets all changes to the "redeem_codes" edge. -func (m *GroupMutation) ResetRedeemCodes() { - m.redeem_codes = nil - m.clearedredeem_codes = false - m.removedredeem_codes = nil -} - -// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. -func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { - if m.subscriptions == nil { - m.subscriptions = make(map[int64]struct{}) - } - for i := range ids { - m.subscriptions[ids[i]] = struct{}{} - } -} - -// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) ClearSubscriptions() { - m.clearedsubscriptions = true -} - -// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. -func (m *GroupMutation) SubscriptionsCleared() bool { - return m.clearedsubscriptions -} - -// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. 
-func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { - if m.removedsubscriptions == nil { - m.removedsubscriptions = make(map[int64]struct{}) - } - for i := range ids { - delete(m.subscriptions, ids[i]) - m.removedsubscriptions[ids[i]] = struct{}{} - } -} - -// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { - for id := range m.removedsubscriptions { - ids = append(ids, id) - } - return +// ClearBodyOverride clears the value of the "body_override" field. +func (m *ChannelMonitorMutation) ClearBodyOverride() { + m.body_override = nil + m.clearedFields[channelmonitor.FieldBodyOverride] = struct{}{} } -// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. -func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { - for id := range m.subscriptions { - ids = append(ids, id) - } - return +// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation. +func (m *ChannelMonitorMutation) BodyOverrideCleared() bool { + _, ok := m.clearedFields[channelmonitor.FieldBodyOverride] + return ok } -// ResetSubscriptions resets all changes to the "subscriptions" edge. -func (m *GroupMutation) ResetSubscriptions() { - m.subscriptions = nil - m.clearedsubscriptions = false - m.removedsubscriptions = nil +// ResetBodyOverride resets all changes to the "body_override" field. +func (m *ChannelMonitorMutation) ResetBodyOverride() { + m.body_override = nil + delete(m.clearedFields, channelmonitor.FieldBodyOverride) } -// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. -func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { - if m.usage_logs == nil { - m.usage_logs = make(map[int64]struct{}) +// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by ids. 
+func (m *ChannelMonitorMutation) AddHistoryIDs(ids ...int64) { + if m.history == nil { + m.history = make(map[int64]struct{}) } for i := range ids { - m.usage_logs[ids[i]] = struct{}{} + m.history[ids[i]] = struct{}{} } } -// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) ClearUsageLogs() { - m.clearedusage_logs = true +// ClearHistory clears the "history" edge to the ChannelMonitorHistory entity. +func (m *ChannelMonitorMutation) ClearHistory() { + m.clearedhistory = true } -// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. -func (m *GroupMutation) UsageLogsCleared() bool { - return m.clearedusage_logs +// HistoryCleared reports if the "history" edge to the ChannelMonitorHistory entity was cleared. +func (m *ChannelMonitorMutation) HistoryCleared() bool { + return m.clearedhistory } -// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. -func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { - if m.removedusage_logs == nil { - m.removedusage_logs = make(map[int64]struct{}) +// RemoveHistoryIDs removes the "history" edge to the ChannelMonitorHistory entity by IDs. +func (m *ChannelMonitorMutation) RemoveHistoryIDs(ids ...int64) { + if m.removedhistory == nil { + m.removedhistory = make(map[int64]struct{}) } for i := range ids { - delete(m.usage_logs, ids[i]) - m.removedusage_logs[ids[i]] = struct{}{} + delete(m.history, ids[i]) + m.removedhistory[ids[i]] = struct{}{} } } -// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { - for id := range m.removedusage_logs { +// RemovedHistory returns the removed IDs of the "history" edge to the ChannelMonitorHistory entity. 
+func (m *ChannelMonitorMutation) RemovedHistoryIDs() (ids []int64) { + for id := range m.removedhistory { ids = append(ids, id) } return } -// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. -func (m *GroupMutation) UsageLogsIDs() (ids []int64) { - for id := range m.usage_logs { +// HistoryIDs returns the "history" edge IDs in the mutation. +func (m *ChannelMonitorMutation) HistoryIDs() (ids []int64) { + for id := range m.history { ids = append(ids, id) } return } -// ResetUsageLogs resets all changes to the "usage_logs" edge. -func (m *GroupMutation) ResetUsageLogs() { - m.usage_logs = nil - m.clearedusage_logs = false - m.removedusage_logs = nil +// ResetHistory resets all changes to the "history" edge. +func (m *ChannelMonitorMutation) ResetHistory() { + m.history = nil + m.clearedhistory = false + m.removedhistory = nil } -// AddAccountIDs adds the "accounts" edge to the Account entity by ids. -func (m *GroupMutation) AddAccountIDs(ids ...int64) { - if m.accounts == nil { - m.accounts = make(map[int64]struct{}) +// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by ids. +func (m *ChannelMonitorMutation) AddDailyRollupIDs(ids ...int64) { + if m.daily_rollups == nil { + m.daily_rollups = make(map[int64]struct{}) } for i := range ids { - m.accounts[ids[i]] = struct{}{} + m.daily_rollups[ids[i]] = struct{}{} } } -// ClearAccounts clears the "accounts" edge to the Account entity. -func (m *GroupMutation) ClearAccounts() { - m.clearedaccounts = true +// ClearDailyRollups clears the "daily_rollups" edge to the ChannelMonitorDailyRollup entity. +func (m *ChannelMonitorMutation) ClearDailyRollups() { + m.cleareddaily_rollups = true } -// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. -func (m *GroupMutation) AccountsCleared() bool { - return m.clearedaccounts +// DailyRollupsCleared reports if the "daily_rollups" edge to the ChannelMonitorDailyRollup entity was cleared. 
+func (m *ChannelMonitorMutation) DailyRollupsCleared() bool { + return m.cleareddaily_rollups } -// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. -func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { - if m.removedaccounts == nil { - m.removedaccounts = make(map[int64]struct{}) +// RemoveDailyRollupIDs removes the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs. +func (m *ChannelMonitorMutation) RemoveDailyRollupIDs(ids ...int64) { + if m.removeddaily_rollups == nil { + m.removeddaily_rollups = make(map[int64]struct{}) } for i := range ids { - delete(m.accounts, ids[i]) - m.removedaccounts[ids[i]] = struct{}{} + delete(m.daily_rollups, ids[i]) + m.removeddaily_rollups[ids[i]] = struct{}{} } } -// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. -func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { - for id := range m.removedaccounts { +// RemovedDailyRollups returns the removed IDs of the "daily_rollups" edge to the ChannelMonitorDailyRollup entity. +func (m *ChannelMonitorMutation) RemovedDailyRollupsIDs() (ids []int64) { + for id := range m.removeddaily_rollups { ids = append(ids, id) } return } -// AccountsIDs returns the "accounts" edge IDs in the mutation. -func (m *GroupMutation) AccountsIDs() (ids []int64) { - for id := range m.accounts { +// DailyRollupsIDs returns the "daily_rollups" edge IDs in the mutation. +func (m *ChannelMonitorMutation) DailyRollupsIDs() (ids []int64) { + for id := range m.daily_rollups { ids = append(ids, id) } return } -// ResetAccounts resets all changes to the "accounts" edge. -func (m *GroupMutation) ResetAccounts() { - m.accounts = nil - m.clearedaccounts = false - m.removedaccounts = nil -} - -// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. 
-func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { - if m.allowed_users == nil { - m.allowed_users = make(map[int64]struct{}) - } - for i := range ids { - m.allowed_users[ids[i]] = struct{}{} - } +// ResetDailyRollups resets all changes to the "daily_rollups" edge. +func (m *ChannelMonitorMutation) ResetDailyRollups() { + m.daily_rollups = nil + m.cleareddaily_rollups = false + m.removeddaily_rollups = nil } -// ClearAllowedUsers clears the "allowed_users" edge to the User entity. -func (m *GroupMutation) ClearAllowedUsers() { - m.clearedallowed_users = true +// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by id. +func (m *ChannelMonitorMutation) SetRequestTemplateID(id int64) { + m.request_template = &id } -// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. -func (m *GroupMutation) AllowedUsersCleared() bool { - return m.clearedallowed_users +// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity. +func (m *ChannelMonitorMutation) ClearRequestTemplate() { + m.clearedrequest_template = true + m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{} } -// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. -func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { - if m.removedallowed_users == nil { - m.removedallowed_users = make(map[int64]struct{}) - } - for i := range ids { - delete(m.allowed_users, ids[i]) - m.removedallowed_users[ids[i]] = struct{}{} - } +// RequestTemplateCleared reports if the "request_template" edge to the ChannelMonitorRequestTemplate entity was cleared. +func (m *ChannelMonitorMutation) RequestTemplateCleared() bool { + return m.TemplateIDCleared() || m.clearedrequest_template } -// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. 
-func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { - for id := range m.removedallowed_users { - ids = append(ids, id) +// RequestTemplateID returns the "request_template" edge ID in the mutation. +func (m *ChannelMonitorMutation) RequestTemplateID() (id int64, exists bool) { + if m.request_template != nil { + return *m.request_template, true } return } -// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. -func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { - for id := range m.allowed_users { - ids = append(ids, id) +// RequestTemplateIDs returns the "request_template" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// RequestTemplateID instead. It exists only for internal usage by the builders. +func (m *ChannelMonitorMutation) RequestTemplateIDs() (ids []int64) { + if id := m.request_template; id != nil { + ids = append(ids, *id) } return } -// ResetAllowedUsers resets all changes to the "allowed_users" edge. -func (m *GroupMutation) ResetAllowedUsers() { - m.allowed_users = nil - m.clearedallowed_users = false - m.removedallowed_users = nil +// ResetRequestTemplate resets all changes to the "request_template" edge. +func (m *ChannelMonitorMutation) ResetRequestTemplate() { + m.request_template = nil + m.clearedrequest_template = false } -// Where appends a list predicates to the GroupMutation builder. -func (m *GroupMutation) Where(ps ...predicate.Group) { +// Where appends a list predicates to the ChannelMonitorMutation builder. +func (m *ChannelMonitorMutation) Where(ps ...predicate.ChannelMonitor) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, +// WhereP appends storage-level predicates to the ChannelMonitorMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. 
-func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.Group, len(ps)) +func (m *ChannelMonitorMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ChannelMonitor, len(ps)) for i := range ps { p[i] = ps[i] } @@ -10183,114 +9762,75 @@ func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *GroupMutation) Op() Op { +func (m *ChannelMonitorMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *GroupMutation) SetOp(op Op) { +func (m *ChannelMonitorMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (Group). -func (m *GroupMutation) Type() string { +// Type returns the node type of this mutation (ChannelMonitor). +func (m *ChannelMonitorMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
-func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) +func (m *ChannelMonitorMutation) Fields() []string { + fields := make([]string, 0, 17) if m.created_at != nil { - fields = append(fields, group.FieldCreatedAt) + fields = append(fields, channelmonitor.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, group.FieldUpdatedAt) - } - if m.deleted_at != nil { - fields = append(fields, group.FieldDeletedAt) + fields = append(fields, channelmonitor.FieldUpdatedAt) } if m.name != nil { - fields = append(fields, group.FieldName) + fields = append(fields, channelmonitor.FieldName) } - if m.description != nil { - fields = append(fields, group.FieldDescription) + if m.provider != nil { + fields = append(fields, channelmonitor.FieldProvider) } - if m.rate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) + if m.endpoint != nil { + fields = append(fields, channelmonitor.FieldEndpoint) } - if m.is_exclusive != nil { - fields = append(fields, group.FieldIsExclusive) + if m.api_key_encrypted != nil { + fields = append(fields, channelmonitor.FieldAPIKeyEncrypted) } - if m.status != nil { - fields = append(fields, group.FieldStatus) - } - if m.platform != nil { - fields = append(fields, group.FieldPlatform) - } - if m.subscription_type != nil { - fields = append(fields, group.FieldSubscriptionType) - } - if m.daily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) - } - if m.weekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) - } - if m.monthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) - } - if m.default_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) - } - if m.image_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) - } - if m.image_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) - } - if m.image_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) - 
} - if m.claude_code_only != nil { - fields = append(fields, group.FieldClaudeCodeOnly) - } - if m.fallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) - } - if m.fallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + if m.primary_model != nil { + fields = append(fields, channelmonitor.FieldPrimaryModel) } - if m.model_routing != nil { - fields = append(fields, group.FieldModelRouting) + if m.extra_models != nil { + fields = append(fields, channelmonitor.FieldExtraModels) } - if m.model_routing_enabled != nil { - fields = append(fields, group.FieldModelRoutingEnabled) + if m.group_name != nil { + fields = append(fields, channelmonitor.FieldGroupName) } - if m.mcp_xml_inject != nil { - fields = append(fields, group.FieldMcpXMLInject) + if m.enabled != nil { + fields = append(fields, channelmonitor.FieldEnabled) } - if m.supported_model_scopes != nil { - fields = append(fields, group.FieldSupportedModelScopes) + if m.interval_seconds != nil { + fields = append(fields, channelmonitor.FieldIntervalSeconds) } - if m.sort_order != nil { - fields = append(fields, group.FieldSortOrder) + if m.last_checked_at != nil { + fields = append(fields, channelmonitor.FieldLastCheckedAt) } - if m.allow_messages_dispatch != nil { - fields = append(fields, group.FieldAllowMessagesDispatch) + if m.created_by != nil { + fields = append(fields, channelmonitor.FieldCreatedBy) } - if m.require_oauth_only != nil { - fields = append(fields, group.FieldRequireOauthOnly) + if m.request_template != nil { + fields = append(fields, channelmonitor.FieldTemplateID) } - if m.require_privacy_set != nil { - fields = append(fields, group.FieldRequirePrivacySet) + if m.extra_headers != nil { + fields = append(fields, channelmonitor.FieldExtraHeaders) } - if m.default_mapped_model != nil { - fields = append(fields, group.FieldDefaultMappedModel) + if m.body_override_mode != nil { + fields = append(fields, 
channelmonitor.FieldBodyOverrideMode) } - if m.messages_dispatch_model_config != nil { - fields = append(fields, group.FieldMessagesDispatchModelConfig) + if m.body_override != nil { + fields = append(fields, channelmonitor.FieldBodyOverride) } return fields } @@ -10298,68 +9838,42 @@ func (m *GroupMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *GroupMutation) Field(name string) (ent.Value, bool) { +func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) { switch name { - case group.FieldCreatedAt: + case channelmonitor.FieldCreatedAt: return m.CreatedAt() - case group.FieldUpdatedAt: + case channelmonitor.FieldUpdatedAt: return m.UpdatedAt() - case group.FieldDeletedAt: - return m.DeletedAt() - case group.FieldName: + case channelmonitor.FieldName: return m.Name() - case group.FieldDescription: - return m.Description() - case group.FieldRateMultiplier: - return m.RateMultiplier() - case group.FieldIsExclusive: - return m.IsExclusive() - case group.FieldStatus: - return m.Status() - case group.FieldPlatform: - return m.Platform() - case group.FieldSubscriptionType: - return m.SubscriptionType() - case group.FieldDailyLimitUsd: - return m.DailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.WeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.MonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.DefaultValidityDays() - case group.FieldImagePrice1k: - return m.ImagePrice1k() - case group.FieldImagePrice2k: - return m.ImagePrice2k() - case group.FieldImagePrice4k: - return m.ImagePrice4k() - case group.FieldClaudeCodeOnly: - return m.ClaudeCodeOnly() - case group.FieldFallbackGroupID: - return m.FallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.FallbackGroupIDOnInvalidRequest() - case group.FieldModelRouting: - return m.ModelRouting() 
- case group.FieldModelRoutingEnabled: - return m.ModelRoutingEnabled() - case group.FieldMcpXMLInject: - return m.McpXMLInject() - case group.FieldSupportedModelScopes: - return m.SupportedModelScopes() - case group.FieldSortOrder: - return m.SortOrder() - case group.FieldAllowMessagesDispatch: - return m.AllowMessagesDispatch() - case group.FieldRequireOauthOnly: - return m.RequireOauthOnly() - case group.FieldRequirePrivacySet: - return m.RequirePrivacySet() - case group.FieldDefaultMappedModel: - return m.DefaultMappedModel() - case group.FieldMessagesDispatchModelConfig: - return m.MessagesDispatchModelConfig() + case channelmonitor.FieldProvider: + return m.Provider() + case channelmonitor.FieldEndpoint: + return m.Endpoint() + case channelmonitor.FieldAPIKeyEncrypted: + return m.APIKeyEncrypted() + case channelmonitor.FieldPrimaryModel: + return m.PrimaryModel() + case channelmonitor.FieldExtraModels: + return m.ExtraModels() + case channelmonitor.FieldGroupName: + return m.GroupName() + case channelmonitor.FieldEnabled: + return m.Enabled() + case channelmonitor.FieldIntervalSeconds: + return m.IntervalSeconds() + case channelmonitor.FieldLastCheckedAt: + return m.LastCheckedAt() + case channelmonitor.FieldCreatedBy: + return m.CreatedBy() + case channelmonitor.FieldTemplateID: + return m.TemplateID() + case channelmonitor.FieldExtraHeaders: + return m.ExtraHeaders() + case channelmonitor.FieldBodyOverrideMode: + return m.BodyOverrideMode() + case channelmonitor.FieldBodyOverride: + return m.BodyOverride() } return nil, false } @@ -10367,327 +9881,183 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. 
-func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case group.FieldCreatedAt: + case channelmonitor.FieldCreatedAt: return m.OldCreatedAt(ctx) - case group.FieldUpdatedAt: + case channelmonitor.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case group.FieldDeletedAt: - return m.OldDeletedAt(ctx) - case group.FieldName: + case channelmonitor.FieldName: return m.OldName(ctx) - case group.FieldDescription: - return m.OldDescription(ctx) - case group.FieldRateMultiplier: - return m.OldRateMultiplier(ctx) - case group.FieldIsExclusive: - return m.OldIsExclusive(ctx) - case group.FieldStatus: - return m.OldStatus(ctx) - case group.FieldPlatform: - return m.OldPlatform(ctx) - case group.FieldSubscriptionType: - return m.OldSubscriptionType(ctx) - case group.FieldDailyLimitUsd: - return m.OldDailyLimitUsd(ctx) - case group.FieldWeeklyLimitUsd: - return m.OldWeeklyLimitUsd(ctx) - case group.FieldMonthlyLimitUsd: - return m.OldMonthlyLimitUsd(ctx) - case group.FieldDefaultValidityDays: - return m.OldDefaultValidityDays(ctx) - case group.FieldImagePrice1k: - return m.OldImagePrice1k(ctx) - case group.FieldImagePrice2k: - return m.OldImagePrice2k(ctx) - case group.FieldImagePrice4k: - return m.OldImagePrice4k(ctx) - case group.FieldClaudeCodeOnly: - return m.OldClaudeCodeOnly(ctx) - case group.FieldFallbackGroupID: - return m.OldFallbackGroupID(ctx) - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.OldFallbackGroupIDOnInvalidRequest(ctx) - case group.FieldModelRouting: - return m.OldModelRouting(ctx) - case group.FieldModelRoutingEnabled: - return m.OldModelRoutingEnabled(ctx) - case group.FieldMcpXMLInject: - return m.OldMcpXMLInject(ctx) - case group.FieldSupportedModelScopes: - return m.OldSupportedModelScopes(ctx) - case group.FieldSortOrder: - return m.OldSortOrder(ctx) - case group.FieldAllowMessagesDispatch: - 
return m.OldAllowMessagesDispatch(ctx) - case group.FieldRequireOauthOnly: - return m.OldRequireOauthOnly(ctx) - case group.FieldRequirePrivacySet: - return m.OldRequirePrivacySet(ctx) - case group.FieldDefaultMappedModel: - return m.OldDefaultMappedModel(ctx) - case group.FieldMessagesDispatchModelConfig: - return m.OldMessagesDispatchModelConfig(ctx) + case channelmonitor.FieldProvider: + return m.OldProvider(ctx) + case channelmonitor.FieldEndpoint: + return m.OldEndpoint(ctx) + case channelmonitor.FieldAPIKeyEncrypted: + return m.OldAPIKeyEncrypted(ctx) + case channelmonitor.FieldPrimaryModel: + return m.OldPrimaryModel(ctx) + case channelmonitor.FieldExtraModels: + return m.OldExtraModels(ctx) + case channelmonitor.FieldGroupName: + return m.OldGroupName(ctx) + case channelmonitor.FieldEnabled: + return m.OldEnabled(ctx) + case channelmonitor.FieldIntervalSeconds: + return m.OldIntervalSeconds(ctx) + case channelmonitor.FieldLastCheckedAt: + return m.OldLastCheckedAt(ctx) + case channelmonitor.FieldCreatedBy: + return m.OldCreatedBy(ctx) + case channelmonitor.FieldTemplateID: + return m.OldTemplateID(ctx) + case channelmonitor.FieldExtraHeaders: + return m.OldExtraHeaders(ctx) + case channelmonitor.FieldBodyOverrideMode: + return m.OldBodyOverrideMode(ctx) + case channelmonitor.FieldBodyOverride: + return m.OldBodyOverride(ctx) } - return nil, fmt.Errorf("unknown Group field %s", name) + return nil, fmt.Errorf("unknown ChannelMonitor field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *GroupMutation) SetField(name string, value ent.Value) error { +func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error { switch name { - case group.FieldCreatedAt: + case channelmonitor.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case group.FieldUpdatedAt: + case channelmonitor.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case group.FieldDeletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDeletedAt(v) - return nil - case group.FieldName: + case channelmonitor.FieldName: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetName(v) return nil - case group.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) - return nil - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRateMultiplier(v) - return nil - case group.FieldIsExclusive: - v, ok := value.(bool) + case channelmonitor.FieldProvider: + v, ok := value.(channelmonitor.Provider) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetIsExclusive(v) + m.SetProvider(v) return nil - case group.FieldStatus: + case channelmonitor.FieldEndpoint: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetEndpoint(v) return nil - case group.FieldPlatform: + case channelmonitor.FieldAPIKeyEncrypted: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPlatform(v) + m.SetAPIKeyEncrypted(v) return nil - case 
group.FieldSubscriptionType: + case channelmonitor.FieldPrimaryModel: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSubscriptionType(v) + m.SetPrimaryModel(v) return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) + case channelmonitor.FieldExtraModels: + v, ok := value.([]string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetDailyLimitUsd(v) + m.SetExtraModels(v) return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) + case channelmonitor.FieldGroupName: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetWeeklyLimitUsd(v) + m.SetGroupName(v) return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) + case channelmonitor.FieldEnabled: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetMonthlyLimitUsd(v) + m.SetEnabled(v) return nil - case group.FieldDefaultValidityDays: + case channelmonitor.FieldIntervalSeconds: v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice4k(v) + m.SetIntervalSeconds(v) return nil - case group.FieldClaudeCodeOnly: - v, ok := value.(bool) + case channelmonitor.FieldLastCheckedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - 
m.SetClaudeCodeOnly(v) + m.SetLastCheckedAt(v) return nil - case group.FieldFallbackGroupID: + case channelmonitor.FieldCreatedBy: v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetFallbackGroupID(v) + m.SetCreatedBy(v) return nil - case group.FieldFallbackGroupIDOnInvalidRequest: + case channelmonitor.FieldTemplateID: v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldModelRouting: - v, ok := value.(map[string][]int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRouting(v) - return nil - case group.FieldModelRoutingEnabled: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRoutingEnabled(v) - return nil - case group.FieldMcpXMLInject: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMcpXMLInject(v) - return nil - case group.FieldSupportedModelScopes: - v, ok := value.([]string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSupportedModelScopes(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSortOrder(v) - return nil - case group.FieldAllowMessagesDispatch: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAllowMessagesDispatch(v) - return nil - case group.FieldRequireOauthOnly: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRequireOauthOnly(v) + m.SetTemplateID(v) return nil - case group.FieldRequirePrivacySet: - v, ok := value.(bool) + case channelmonitor.FieldExtraHeaders: + v, ok := value.(map[string]string) if !ok { return 
fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRequirePrivacySet(v) + m.SetExtraHeaders(v) return nil - case group.FieldDefaultMappedModel: + case channelmonitor.FieldBodyOverrideMode: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetDefaultMappedModel(v) + m.SetBodyOverrideMode(v) return nil - case group.FieldMessagesDispatchModelConfig: - v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + case channelmonitor.FieldBodyOverride: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetMessagesDispatchModelConfig(v) + m.SetBodyOverride(v) return nil } - return fmt.Errorf("unknown Group field %s", name) + return fmt.Errorf("unknown ChannelMonitor field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *GroupMutation) AddedFields() []string { +func (m *ChannelMonitorMutation) AddedFields() []string { var fields []string - if m.addrate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) - } - if m.adddaily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) - } - if m.addweekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) - } - if m.addmonthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) - } - if m.adddefault_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) - } - if m.addimage_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) + if m.addinterval_seconds != nil { + fields = append(fields, channelmonitor.FieldIntervalSeconds) } - if m.addimage_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) - } - if m.addimage_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) - } - if m.addfallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) - } - if 
m.addfallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) - } - if m.addsort_order != nil { - fields = append(fields, group.FieldSortOrder) + if m.addcreated_by != nil { + fields = append(fields, channelmonitor.FieldCreatedBy) } return fields } @@ -10695,30 +10065,12 @@ func (m *GroupMutation) AddedFields() []string { // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { +func (m *ChannelMonitorMutation) AddedField(name string) (ent.Value, bool) { switch name { - case group.FieldRateMultiplier: - return m.AddedRateMultiplier() - case group.FieldDailyLimitUsd: - return m.AddedDailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.AddedWeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.AddedMonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.AddedDefaultValidityDays() - case group.FieldImagePrice1k: - return m.AddedImagePrice1k() - case group.FieldImagePrice2k: - return m.AddedImagePrice2k() - case group.FieldImagePrice4k: - return m.AddedImagePrice4k() - case group.FieldFallbackGroupID: - return m.AddedFallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.AddedFallbackGroupIDOnInvalidRequest() - case group.FieldSortOrder: - return m.AddedSortOrder() + case channelmonitor.FieldIntervalSeconds: + return m.AddedIntervalSeconds() + case channelmonitor.FieldCreatedBy: + return m.AddedCreatedBy() } return nil, false } @@ -10726,404 +10078,195 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *GroupMutation) AddField(name string, value ent.Value) error { +func (m *ChannelMonitorMutation) AddField(name string, value ent.Value) error { switch name { - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRateMultiplier(v) - return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddDailyLimitUsd(v) - return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddWeeklyLimitUsd(v) - return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddMonthlyLimitUsd(v) - return nil - case group.FieldDefaultValidityDays: + case channelmonitor.FieldIntervalSeconds: v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice4k(v) - return nil - case group.FieldFallbackGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFallbackGroupID(v) + m.AddIntervalSeconds(v) return nil - case group.FieldFallbackGroupIDOnInvalidRequest: + case channelmonitor.FieldCreatedBy: v, ok := value.(int64) if !ok { return 
fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) + m.AddCreatedBy(v) return nil } - return fmt.Errorf("unknown Group numeric field %s", name) + return fmt.Errorf("unknown ChannelMonitor numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *GroupMutation) ClearedFields() []string { +func (m *ChannelMonitorMutation) ClearedFields() []string { var fields []string - if m.FieldCleared(group.FieldDeletedAt) { - fields = append(fields, group.FieldDeletedAt) + if m.FieldCleared(channelmonitor.FieldGroupName) { + fields = append(fields, channelmonitor.FieldGroupName) } - if m.FieldCleared(group.FieldDescription) { - fields = append(fields, group.FieldDescription) - } - if m.FieldCleared(group.FieldDailyLimitUsd) { - fields = append(fields, group.FieldDailyLimitUsd) - } - if m.FieldCleared(group.FieldWeeklyLimitUsd) { - fields = append(fields, group.FieldWeeklyLimitUsd) - } - if m.FieldCleared(group.FieldMonthlyLimitUsd) { - fields = append(fields, group.FieldMonthlyLimitUsd) - } - if m.FieldCleared(group.FieldImagePrice1k) { - fields = append(fields, group.FieldImagePrice1k) - } - if m.FieldCleared(group.FieldImagePrice2k) { - fields = append(fields, group.FieldImagePrice2k) - } - if m.FieldCleared(group.FieldImagePrice4k) { - fields = append(fields, group.FieldImagePrice4k) + if m.FieldCleared(channelmonitor.FieldLastCheckedAt) { + fields = append(fields, channelmonitor.FieldLastCheckedAt) } - if m.FieldCleared(group.FieldFallbackGroupID) { - fields = append(fields, group.FieldFallbackGroupID) - } - if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + if m.FieldCleared(channelmonitor.FieldTemplateID) { + 
fields = append(fields, channelmonitor.FieldTemplateID) } - if m.FieldCleared(group.FieldModelRouting) { - fields = append(fields, group.FieldModelRouting) + if m.FieldCleared(channelmonitor.FieldBodyOverride) { + fields = append(fields, channelmonitor.FieldBodyOverride) } return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *GroupMutation) FieldCleared(name string) bool { +func (m *ChannelMonitorMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *GroupMutation) ClearField(name string) error { +func (m *ChannelMonitorMutation) ClearField(name string) error { switch name { - case group.FieldDeletedAt: - m.ClearDeletedAt() - return nil - case group.FieldDescription: - m.ClearDescription() - return nil - case group.FieldDailyLimitUsd: - m.ClearDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ClearWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ClearMonthlyLimitUsd() - return nil - case group.FieldImagePrice1k: - m.ClearImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ClearImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ClearImagePrice4k() + case channelmonitor.FieldGroupName: + m.ClearGroupName() return nil - case group.FieldFallbackGroupID: - m.ClearFallbackGroupID() + case channelmonitor.FieldLastCheckedAt: + m.ClearLastCheckedAt() return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ClearFallbackGroupIDOnInvalidRequest() + case channelmonitor.FieldTemplateID: + m.ClearTemplateID() return nil - case group.FieldModelRouting: - m.ClearModelRouting() + case channelmonitor.FieldBodyOverride: + m.ClearBodyOverride() return nil } - return fmt.Errorf("unknown Group nullable field %s", name) + return fmt.Errorf("unknown ChannelMonitor 
nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *GroupMutation) ResetField(name string) error { +func (m *ChannelMonitorMutation) ResetField(name string) error { switch name { - case group.FieldCreatedAt: + case channelmonitor.FieldCreatedAt: m.ResetCreatedAt() return nil - case group.FieldUpdatedAt: + case channelmonitor.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case group.FieldDeletedAt: - m.ResetDeletedAt() - return nil - case group.FieldName: + case channelmonitor.FieldName: m.ResetName() return nil - case group.FieldDescription: - m.ResetDescription() - return nil - case group.FieldRateMultiplier: - m.ResetRateMultiplier() - return nil - case group.FieldIsExclusive: - m.ResetIsExclusive() - return nil - case group.FieldStatus: - m.ResetStatus() - return nil - case group.FieldPlatform: - m.ResetPlatform() - return nil - case group.FieldSubscriptionType: - m.ResetSubscriptionType() - return nil - case group.FieldDailyLimitUsd: - m.ResetDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ResetWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ResetMonthlyLimitUsd() - return nil - case group.FieldDefaultValidityDays: - m.ResetDefaultValidityDays() - return nil - case group.FieldImagePrice1k: - m.ResetImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ResetImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ResetImagePrice4k() + case channelmonitor.FieldProvider: + m.ResetProvider() return nil - case group.FieldClaudeCodeOnly: - m.ResetClaudeCodeOnly() + case channelmonitor.FieldEndpoint: + m.ResetEndpoint() return nil - case group.FieldFallbackGroupID: - m.ResetFallbackGroupID() + case channelmonitor.FieldAPIKeyEncrypted: + m.ResetAPIKeyEncrypted() return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ResetFallbackGroupIDOnInvalidRequest() + case 
channelmonitor.FieldPrimaryModel: + m.ResetPrimaryModel() return nil - case group.FieldModelRouting: - m.ResetModelRouting() + case channelmonitor.FieldExtraModels: + m.ResetExtraModels() return nil - case group.FieldModelRoutingEnabled: - m.ResetModelRoutingEnabled() + case channelmonitor.FieldGroupName: + m.ResetGroupName() return nil - case group.FieldMcpXMLInject: - m.ResetMcpXMLInject() + case channelmonitor.FieldEnabled: + m.ResetEnabled() return nil - case group.FieldSupportedModelScopes: - m.ResetSupportedModelScopes() + case channelmonitor.FieldIntervalSeconds: + m.ResetIntervalSeconds() return nil - case group.FieldSortOrder: - m.ResetSortOrder() + case channelmonitor.FieldLastCheckedAt: + m.ResetLastCheckedAt() return nil - case group.FieldAllowMessagesDispatch: - m.ResetAllowMessagesDispatch() + case channelmonitor.FieldCreatedBy: + m.ResetCreatedBy() return nil - case group.FieldRequireOauthOnly: - m.ResetRequireOauthOnly() + case channelmonitor.FieldTemplateID: + m.ResetTemplateID() return nil - case group.FieldRequirePrivacySet: - m.ResetRequirePrivacySet() + case channelmonitor.FieldExtraHeaders: + m.ResetExtraHeaders() return nil - case group.FieldDefaultMappedModel: - m.ResetDefaultMappedModel() + case channelmonitor.FieldBodyOverrideMode: + m.ResetBodyOverrideMode() return nil - case group.FieldMessagesDispatchModelConfig: - m.ResetMessagesDispatchModelConfig() + case channelmonitor.FieldBodyOverride: + m.ResetBodyOverride() return nil } - return fmt.Errorf("unknown Group field %s", name) + return fmt.Errorf("unknown ChannelMonitor field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. 
-func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 6) - if m.api_keys != nil { - edges = append(edges, group.EdgeAPIKeys) - } - if m.redeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.subscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) - } - if m.usage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) +func (m *ChannelMonitorMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.history != nil { + edges = append(edges, channelmonitor.EdgeHistory) } - if m.accounts != nil { - edges = append(edges, group.EdgeAccounts) + if m.daily_rollups != nil { + edges = append(edges, channelmonitor.EdgeDailyRollups) } - if m.allowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + if m.request_template != nil { + edges = append(edges, channelmonitor.EdgeRequestTemplate) } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. 
-func (m *GroupMutation) AddedIDs(name string) []ent.Value { +func (m *ChannelMonitorMutation) AddedIDs(name string) []ent.Value { switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.api_keys)) - for id := range m.api_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.redeem_codes)) - for id := range m.redeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.subscriptions)) - for id := range m.subscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.usage_logs)) - for id := range m.usage_logs { + case channelmonitor.EdgeHistory: + ids := make([]ent.Value, 0, len(m.history)) + for id := range m.history { ids = append(ids, id) } return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.accounts)) - for id := range m.accounts { + case channelmonitor.EdgeDailyRollups: + ids := make([]ent.Value, 0, len(m.daily_rollups)) + for id := range m.daily_rollups { ids = append(ids, id) } return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.allowed_users)) - for id := range m.allowed_users { - ids = append(ids, id) + case channelmonitor.EdgeRequestTemplate: + if id := m.request_template; id != nil { + return []ent.Value{*id} } - return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
-func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 6) - if m.removedapi_keys != nil { - edges = append(edges, group.EdgeAPIKeys) - } - if m.removedredeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.removedsubscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) - } - if m.removedusage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) - } - if m.removedaccounts != nil { - edges = append(edges, group.EdgeAccounts) +func (m *ChannelMonitorMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedhistory != nil { + edges = append(edges, channelmonitor.EdgeHistory) } - if m.removedallowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + if m.removeddaily_rollups != nil { + edges = append(edges, channelmonitor.EdgeDailyRollups) } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *GroupMutation) RemovedIDs(name string) []ent.Value { +func (m *ChannelMonitorMutation) RemovedIDs(name string) []ent.Value { switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.removedapi_keys)) - for id := range m.removedapi_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.removedredeem_codes)) - for id := range m.removedredeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.removedsubscriptions)) - for id := range m.removedsubscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.removedusage_logs)) - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.removedaccounts)) - for id := range m.removedaccounts { + case channelmonitor.EdgeHistory: + ids := make([]ent.Value, 0, 
len(m.removedhistory)) + for id := range m.removedhistory { ids = append(ids, id) } return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.removedallowed_users)) - for id := range m.removedallowed_users { + case channelmonitor.EdgeDailyRollups: + ids := make([]ent.Value, 0, len(m.removeddaily_rollups)) + for id := range m.removeddaily_rollups { ids = append(ids, id) } return ids @@ -11132,118 +10275,110 @@ func (m *GroupMutation) RemovedIDs(name string) []ent.Value { } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 6) - if m.clearedapi_keys { - edges = append(edges, group.EdgeAPIKeys) - } - if m.clearedredeem_codes { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.clearedsubscriptions { - edges = append(edges, group.EdgeSubscriptions) - } - if m.clearedusage_logs { - edges = append(edges, group.EdgeUsageLogs) +func (m *ChannelMonitorMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.clearedhistory { + edges = append(edges, channelmonitor.EdgeHistory) } - if m.clearedaccounts { - edges = append(edges, group.EdgeAccounts) + if m.cleareddaily_rollups { + edges = append(edges, channelmonitor.EdgeDailyRollups) } - if m.clearedallowed_users { - edges = append(edges, group.EdgeAllowedUsers) + if m.clearedrequest_template { + edges = append(edges, channelmonitor.EdgeRequestTemplate) } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. 
-func (m *GroupMutation) EdgeCleared(name string) bool { +func (m *ChannelMonitorMutation) EdgeCleared(name string) bool { switch name { - case group.EdgeAPIKeys: - return m.clearedapi_keys - case group.EdgeRedeemCodes: - return m.clearedredeem_codes - case group.EdgeSubscriptions: - return m.clearedsubscriptions - case group.EdgeUsageLogs: - return m.clearedusage_logs - case group.EdgeAccounts: - return m.clearedaccounts - case group.EdgeAllowedUsers: - return m.clearedallowed_users + case channelmonitor.EdgeHistory: + return m.clearedhistory + case channelmonitor.EdgeDailyRollups: + return m.cleareddaily_rollups + case channelmonitor.EdgeRequestTemplate: + return m.clearedrequest_template } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *GroupMutation) ClearEdge(name string) error { +func (m *ChannelMonitorMutation) ClearEdge(name string) error { switch name { + case channelmonitor.EdgeRequestTemplate: + m.ClearRequestTemplate() + return nil } - return fmt.Errorf("unknown Group unique edge %s", name) + return fmt.Errorf("unknown ChannelMonitor unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. 
-func (m *GroupMutation) ResetEdge(name string) error { +func (m *ChannelMonitorMutation) ResetEdge(name string) error { switch name { - case group.EdgeAPIKeys: - m.ResetAPIKeys() - return nil - case group.EdgeRedeemCodes: - m.ResetRedeemCodes() + case channelmonitor.EdgeHistory: + m.ResetHistory() return nil - case group.EdgeSubscriptions: - m.ResetSubscriptions() - return nil - case group.EdgeUsageLogs: - m.ResetUsageLogs() - return nil - case group.EdgeAccounts: - m.ResetAccounts() + case channelmonitor.EdgeDailyRollups: + m.ResetDailyRollups() return nil - case group.EdgeAllowedUsers: - m.ResetAllowedUsers() + case channelmonitor.EdgeRequestTemplate: + m.ResetRequestTemplate() return nil } - return fmt.Errorf("unknown Group edge %s", name) + return fmt.Errorf("unknown ChannelMonitor edge %s", name) } -// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. -type IdempotencyRecordMutation struct { +// ChannelMonitorDailyRollupMutation represents an operation that mutates the ChannelMonitorDailyRollup nodes in the graph. +type ChannelMonitorDailyRollupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - scope *string - idempotency_key_hash *string - request_fingerprint *string - status *string - response_status *int - addresponse_status *int - response_body *string - error_reason *string - locked_until *time.Time - expires_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*IdempotencyRecord, error) - predicates []predicate.IdempotencyRecord -} - -var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) - -// idempotencyrecordOption allows management of the mutation configuration using functional options. -type idempotencyrecordOption func(*IdempotencyRecordMutation) - -// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. 
-func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { - m := &IdempotencyRecordMutation{ + op Op + typ string + id *int64 + model *string + bucket_date *time.Time + total_checks *int + addtotal_checks *int + ok_count *int + addok_count *int + operational_count *int + addoperational_count *int + degraded_count *int + adddegraded_count *int + failed_count *int + addfailed_count *int + error_count *int + adderror_count *int + sum_latency_ms *int64 + addsum_latency_ms *int64 + count_latency *int + addcount_latency *int + sum_ping_latency_ms *int64 + addsum_ping_latency_ms *int64 + count_ping_latency *int + addcount_ping_latency *int + computed_at *time.Time + clearedFields map[string]struct{} + monitor *int64 + clearedmonitor bool + done bool + oldValue func(context.Context) (*ChannelMonitorDailyRollup, error) + predicates []predicate.ChannelMonitorDailyRollup +} + +var _ ent.Mutation = (*ChannelMonitorDailyRollupMutation)(nil) + +// channelmonitordailyrollupOption allows management of the mutation configuration using functional options. +type channelmonitordailyrollupOption func(*ChannelMonitorDailyRollupMutation) + +// newChannelMonitorDailyRollupMutation creates new mutation for the ChannelMonitorDailyRollup entity. +func newChannelMonitorDailyRollupMutation(c config, op Op, opts ...channelmonitordailyrollupOption) *ChannelMonitorDailyRollupMutation { + m := &ChannelMonitorDailyRollupMutation{ config: c, op: op, - typ: TypeIdempotencyRecord, + typ: TypeChannelMonitorDailyRollup, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -11252,20 +10387,20 @@ func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOpti return m } -// withIdempotencyRecordID sets the ID field of the mutation. -func withIdempotencyRecordID(id int64) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { +// withChannelMonitorDailyRollupID sets the ID field of the mutation. 
+func withChannelMonitorDailyRollupID(id int64) channelmonitordailyrollupOption { + return func(m *ChannelMonitorDailyRollupMutation) { var ( err error once sync.Once - value *IdempotencyRecord + value *ChannelMonitorDailyRollup ) - m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + m.oldValue = func(ctx context.Context) (*ChannelMonitorDailyRollup, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().IdempotencyRecord.Get(ctx, id) + value, err = m.Client().ChannelMonitorDailyRollup.Get(ctx, id) } }) return value, err @@ -11274,10 +10409,10 @@ func withIdempotencyRecordID(id int64) idempotencyrecordOption { } } -// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. -func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { - m.oldValue = func(context.Context) (*IdempotencyRecord, error) { +// withChannelMonitorDailyRollup sets the old ChannelMonitorDailyRollup of the mutation. +func withChannelMonitorDailyRollup(node *ChannelMonitorDailyRollup) channelmonitordailyrollupOption { + return func(m *ChannelMonitorDailyRollupMutation) { + m.oldValue = func(context.Context) (*ChannelMonitorDailyRollup, error) { return node, nil } m.id = &node.ID @@ -11286,7 +10421,7 @@ func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m IdempotencyRecordMutation) Client() *Client { +func (m ChannelMonitorDailyRollupMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -11294,7 +10429,7 @@ func (m IdempotencyRecordMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. 
-func (m IdempotencyRecordMutation) Tx() (*Tx, error) { +func (m ChannelMonitorDailyRollupMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -11305,7 +10440,7 @@ func (m IdempotencyRecordMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { +func (m *ChannelMonitorDailyRollupMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -11316,7 +10451,7 @@ func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *ChannelMonitorDailyRollupMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -11325,490 +10460,752 @@ func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + return m.Client().ChannelMonitorDailyRollup.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetCreatedAt sets the "created_at" field. -func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetMonitorID sets the "monitor_id" field. +func (m *ChannelMonitorDailyRollupMutation) SetMonitorID(i int64) { + m.monitor = &i } -// CreatedAt returns the value of the "created_at" field in the mutation. 
-func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// MonitorID returns the value of the "monitor_id" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) MonitorID() (r int64, exists bool) { + v := m.monitor if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldMonitorID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldMonitorID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldMonitorID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldMonitorID: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.MonitorID, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *IdempotencyRecordMutation) ResetCreatedAt() { - m.created_at = nil +// ResetMonitorID resets all changes to the "monitor_id" field. 
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitorID() { + m.monitor = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetModel sets the "model" field. +func (m *ChannelMonitorDailyRollupMutation) SetModel(s string) { + m.model = &s } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// Model returns the value of the "model" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) Model() (r string, exists bool) { + v := m.model if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldModel returns the old "model" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldModel(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldModel is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldModel requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldModel: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.Model, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *IdempotencyRecordMutation) ResetUpdatedAt() { - m.updated_at = nil +// ResetModel resets all changes to the "model" field. +func (m *ChannelMonitorDailyRollupMutation) ResetModel() { + m.model = nil } -// SetScope sets the "scope" field. -func (m *IdempotencyRecordMutation) SetScope(s string) { - m.scope = &s +// SetBucketDate sets the "bucket_date" field. +func (m *ChannelMonitorDailyRollupMutation) SetBucketDate(t time.Time) { + m.bucket_date = &t } -// Scope returns the value of the "scope" field in the mutation. -func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { - v := m.scope +// BucketDate returns the value of the "bucket_date" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) BucketDate() (r time.Time, exists bool) { + v := m.bucket_date if v == nil { return } return *v, true } -// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. 
+// OldBucketDate returns the old "bucket_date" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldBucketDate(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldScope is only allowed on UpdateOne operations") + return v, errors.New("OldBucketDate is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldScope requires an ID field in the mutation") + return v, errors.New("OldBucketDate requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldScope: %w", err) + return v, fmt.Errorf("querying old value for OldBucketDate: %w", err) } - return oldValue.Scope, nil + return oldValue.BucketDate, nil } -// ResetScope resets all changes to the "scope" field. -func (m *IdempotencyRecordMutation) ResetScope() { - m.scope = nil +// ResetBucketDate resets all changes to the "bucket_date" field. +func (m *ChannelMonitorDailyRollupMutation) ResetBucketDate() { + m.bucket_date = nil } -// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. -func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { - m.idempotency_key_hash = &s +// SetTotalChecks sets the "total_checks" field. +func (m *ChannelMonitorDailyRollupMutation) SetTotalChecks(i int) { + m.total_checks = &i + m.addtotal_checks = nil } -// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. 
-func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { - v := m.idempotency_key_hash +// TotalChecks returns the value of the "total_checks" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) TotalChecks() (r int, exists bool) { + v := m.total_checks if v == nil { return } return *v, true } -// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldTotalChecks returns the old "total_checks" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldTotalChecks(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + return v, errors.New("OldTotalChecks is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + return v, errors.New("OldTotalChecks requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + return v, fmt.Errorf("querying old value for OldTotalChecks: %w", err) } - return oldValue.IdempotencyKeyHash, nil + return oldValue.TotalChecks, nil } -// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. 
-func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { - m.idempotency_key_hash = nil +// AddTotalChecks adds i to the "total_checks" field. +func (m *ChannelMonitorDailyRollupMutation) AddTotalChecks(i int) { + if m.addtotal_checks != nil { + *m.addtotal_checks += i + } else { + m.addtotal_checks = &i + } } -// SetRequestFingerprint sets the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { - m.request_fingerprint = &s +// AddedTotalChecks returns the value that was added to the "total_checks" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedTotalChecks() (r int, exists bool) { + v := m.addtotal_checks + if v == nil { + return + } + return *v, true } -// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. -func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { - v := m.request_fingerprint +// ResetTotalChecks resets all changes to the "total_checks" field. +func (m *ChannelMonitorDailyRollupMutation) ResetTotalChecks() { + m.total_checks = nil + m.addtotal_checks = nil +} + +// SetOkCount sets the "ok_count" field. +func (m *ChannelMonitorDailyRollupMutation) SetOkCount(i int) { + m.ok_count = &i + m.addok_count = nil +} + +// OkCount returns the value of the "ok_count" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) OkCount() (r int, exists bool) { + v := m.ok_count if v == nil { return } return *v, true } -// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldOkCount returns the old "ok_count" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldOkCount(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + return v, errors.New("OldOkCount is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + return v, errors.New("OldOkCount requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + return v, fmt.Errorf("querying old value for OldOkCount: %w", err) } - return oldValue.RequestFingerprint, nil + return oldValue.OkCount, nil } -// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { - m.request_fingerprint = nil +// AddOkCount adds i to the "ok_count" field. +func (m *ChannelMonitorDailyRollupMutation) AddOkCount(i int) { + if m.addok_count != nil { + *m.addok_count += i + } else { + m.addok_count = &i + } } -// SetStatus sets the "status" field. -func (m *IdempotencyRecordMutation) SetStatus(s string) { - m.status = &s +// AddedOkCount returns the value that was added to the "ok_count" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedOkCount() (r int, exists bool) { + v := m.addok_count + if v == nil { + return + } + return *v, true } -// Status returns the value of the "status" field in the mutation. -func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { - v := m.status +// ResetOkCount resets all changes to the "ok_count" field. 
+func (m *ChannelMonitorDailyRollupMutation) ResetOkCount() { + m.ok_count = nil + m.addok_count = nil +} + +// SetOperationalCount sets the "operational_count" field. +func (m *ChannelMonitorDailyRollupMutation) SetOperationalCount(i int) { + m.operational_count = &i + m.addoperational_count = nil +} + +// OperationalCount returns the value of the "operational_count" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) OperationalCount() (r int, exists bool) { + v := m.operational_count if v == nil { return } return *v, true } -// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldOperationalCount returns the old "operational_count" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldOperationalCount(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") + return v, errors.New("OldOperationalCount is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + return v, errors.New("OldOperationalCount requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + return v, fmt.Errorf("querying old value for OldOperationalCount: %w", err) } - return oldValue.Status, nil + return oldValue.OperationalCount, nil } -// ResetStatus resets all changes to the "status" field. 
-func (m *IdempotencyRecordMutation) ResetStatus() { - m.status = nil +// AddOperationalCount adds i to the "operational_count" field. +func (m *ChannelMonitorDailyRollupMutation) AddOperationalCount(i int) { + if m.addoperational_count != nil { + *m.addoperational_count += i + } else { + m.addoperational_count = &i + } } -// SetResponseStatus sets the "response_status" field. -func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { - m.response_status = &i - m.addresponse_status = nil +// AddedOperationalCount returns the value that was added to the "operational_count" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedOperationalCount() (r int, exists bool) { + v := m.addoperational_count + if v == nil { + return + } + return *v, true } -// ResponseStatus returns the value of the "response_status" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { - v := m.response_status +// ResetOperationalCount resets all changes to the "operational_count" field. +func (m *ChannelMonitorDailyRollupMutation) ResetOperationalCount() { + m.operational_count = nil + m.addoperational_count = nil +} + +// SetDegradedCount sets the "degraded_count" field. +func (m *ChannelMonitorDailyRollupMutation) SetDegradedCount(i int) { + m.degraded_count = &i + m.adddegraded_count = nil +} + +// DegradedCount returns the value of the "degraded_count" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) DegradedCount() (r int, exists bool) { + v := m.degraded_count if v == nil { return } return *v, true } -// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldDegradedCount returns the old "degraded_count" field's value of the ChannelMonitorDailyRollup entity. 
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldDegradedCount(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + return v, errors.New("OldDegradedCount is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseStatus requires an ID field in the mutation") + return v, errors.New("OldDegradedCount requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + return v, fmt.Errorf("querying old value for OldDegradedCount: %w", err) } - return oldValue.ResponseStatus, nil + return oldValue.DegradedCount, nil } -// AddResponseStatus adds i to the "response_status" field. -func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { - if m.addresponse_status != nil { - *m.addresponse_status += i +// AddDegradedCount adds i to the "degraded_count" field. +func (m *ChannelMonitorDailyRollupMutation) AddDegradedCount(i int) { + if m.adddegraded_count != nil { + *m.adddegraded_count += i } else { - m.addresponse_status = &i + m.adddegraded_count = &i } } -// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. -func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { - v := m.addresponse_status +// AddedDegradedCount returns the value that was added to the "degraded_count" field in this mutation. 
+func (m *ChannelMonitorDailyRollupMutation) AddedDegradedCount() (r int, exists bool) { + v := m.adddegraded_count if v == nil { return } return *v, true } -// ClearResponseStatus clears the value of the "response_status" field. -func (m *IdempotencyRecordMutation) ClearResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} -} - -// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] - return ok -} - -// ResetResponseStatus resets all changes to the "response_status" field. -func (m *IdempotencyRecordMutation) ResetResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +// ResetDegradedCount resets all changes to the "degraded_count" field. +func (m *ChannelMonitorDailyRollupMutation) ResetDegradedCount() { + m.degraded_count = nil + m.adddegraded_count = nil } -// SetResponseBody sets the "response_body" field. -func (m *IdempotencyRecordMutation) SetResponseBody(s string) { - m.response_body = &s +// SetFailedCount sets the "failed_count" field. +func (m *ChannelMonitorDailyRollupMutation) SetFailedCount(i int) { + m.failed_count = &i + m.addfailed_count = nil } -// ResponseBody returns the value of the "response_body" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { - v := m.response_body +// FailedCount returns the value of the "failed_count" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) FailedCount() (r int, exists bool) { + v := m.failed_count if v == nil { return } return *v, true } -// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. 
-// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldFailedCount returns the old "failed_count" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldFailedCount(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + return v, errors.New("OldFailedCount is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseBody requires an ID field in the mutation") + return v, errors.New("OldFailedCount requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + return v, fmt.Errorf("querying old value for OldFailedCount: %w", err) } - return oldValue.ResponseBody, nil + return oldValue.FailedCount, nil } -// ClearResponseBody clears the value of the "response_body" field. -func (m *IdempotencyRecordMutation) ClearResponseBody() { - m.response_body = nil - m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +// AddFailedCount adds i to the "failed_count" field. +func (m *ChannelMonitorDailyRollupMutation) AddFailedCount(i int) { + if m.addfailed_count != nil { + *m.addfailed_count += i + } else { + m.addfailed_count = &i + } } -// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. 
-func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] - return ok +// AddedFailedCount returns the value that was added to the "failed_count" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedFailedCount() (r int, exists bool) { + v := m.addfailed_count + if v == nil { + return + } + return *v, true } -// ResetResponseBody resets all changes to the "response_body" field. -func (m *IdempotencyRecordMutation) ResetResponseBody() { - m.response_body = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +// ResetFailedCount resets all changes to the "failed_count" field. +func (m *ChannelMonitorDailyRollupMutation) ResetFailedCount() { + m.failed_count = nil + m.addfailed_count = nil } -// SetErrorReason sets the "error_reason" field. -func (m *IdempotencyRecordMutation) SetErrorReason(s string) { - m.error_reason = &s +// SetErrorCount sets the "error_count" field. +func (m *ChannelMonitorDailyRollupMutation) SetErrorCount(i int) { + m.error_count = &i + m.adderror_count = nil } -// ErrorReason returns the value of the "error_reason" field in the mutation. -func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { - v := m.error_reason +// ErrorCount returns the value of the "error_count" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) ErrorCount() (r int, exists bool) { + v := m.error_count if v == nil { return } return *v, true } -// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldErrorCount returns the old "error_count" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldErrorCount(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + return v, errors.New("OldErrorCount is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorReason requires an ID field in the mutation") + return v, errors.New("OldErrorCount requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + return v, fmt.Errorf("querying old value for OldErrorCount: %w", err) } - return oldValue.ErrorReason, nil + return oldValue.ErrorCount, nil } -// ClearErrorReason clears the value of the "error_reason" field. -func (m *IdempotencyRecordMutation) ClearErrorReason() { - m.error_reason = nil - m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +// AddErrorCount adds i to the "error_count" field. +func (m *ChannelMonitorDailyRollupMutation) AddErrorCount(i int) { + if m.adderror_count != nil { + *m.adderror_count += i + } else { + m.adderror_count = &i + } } -// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] - return ok +// AddedErrorCount returns the value that was added to the "error_count" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedErrorCount() (r int, exists bool) { + v := m.adderror_count + if v == nil { + return + } + return *v, true } -// ResetErrorReason resets all changes to the "error_reason" field. 
-func (m *IdempotencyRecordMutation) ResetErrorReason() { - m.error_reason = nil - delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +// ResetErrorCount resets all changes to the "error_count" field. +func (m *ChannelMonitorDailyRollupMutation) ResetErrorCount() { + m.error_count = nil + m.adderror_count = nil } -// SetLockedUntil sets the "locked_until" field. -func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { - m.locked_until = &t +// SetSumLatencyMs sets the "sum_latency_ms" field. +func (m *ChannelMonitorDailyRollupMutation) SetSumLatencyMs(i int64) { + m.sum_latency_ms = &i + m.addsum_latency_ms = nil } -// LockedUntil returns the value of the "locked_until" field in the mutation. -func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { - v := m.locked_until +// SumLatencyMs returns the value of the "sum_latency_ms" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) SumLatencyMs() (r int64, exists bool) { + v := m.sum_latency_ms if v == nil { return } return *v, true } -// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldSumLatencyMs returns the old "sum_latency_ms" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldSumLatencyMs(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + return v, errors.New("OldSumLatencyMs is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLockedUntil requires an ID field in the mutation") + return v, errors.New("OldSumLatencyMs requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + return v, fmt.Errorf("querying old value for OldSumLatencyMs: %w", err) } - return oldValue.LockedUntil, nil + return oldValue.SumLatencyMs, nil } -// ClearLockedUntil clears the value of the "locked_until" field. -func (m *IdempotencyRecordMutation) ClearLockedUntil() { - m.locked_until = nil - m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +// AddSumLatencyMs adds i to the "sum_latency_ms" field. +func (m *ChannelMonitorDailyRollupMutation) AddSumLatencyMs(i int64) { + if m.addsum_latency_ms != nil { + *m.addsum_latency_ms += i + } else { + m.addsum_latency_ms = &i + } } -// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] - return ok +// AddedSumLatencyMs returns the value that was added to the "sum_latency_ms" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedSumLatencyMs() (r int64, exists bool) { + v := m.addsum_latency_ms + if v == nil { + return + } + return *v, true } -// ResetLockedUntil resets all changes to the "locked_until" field. 
-func (m *IdempotencyRecordMutation) ResetLockedUntil() { - m.locked_until = nil - delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +// ResetSumLatencyMs resets all changes to the "sum_latency_ms" field. +func (m *ChannelMonitorDailyRollupMutation) ResetSumLatencyMs() { + m.sum_latency_ms = nil + m.addsum_latency_ms = nil } -// SetExpiresAt sets the "expires_at" field. -func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// SetCountLatency sets the "count_latency" field. +func (m *ChannelMonitorDailyRollupMutation) SetCountLatency(i int) { + m.count_latency = &i + m.addcount_latency = nil } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at +// CountLatency returns the value of the "count_latency" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) CountLatency() (r int, exists bool) { + v := m.count_latency if v == nil { return } return *v, true } -// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldCountLatency returns the old "count_latency" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { +func (m *ChannelMonitorDailyRollupMutation) OldCountLatency(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + return v, errors.New("OldCountLatency is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + return v, errors.New("OldCountLatency requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + return v, fmt.Errorf("querying old value for OldCountLatency: %w", err) } - return oldValue.ExpiresAt, nil + return oldValue.CountLatency, nil } -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *IdempotencyRecordMutation) ResetExpiresAt() { - m.expires_at = nil +// AddCountLatency adds i to the "count_latency" field. +func (m *ChannelMonitorDailyRollupMutation) AddCountLatency(i int) { + if m.addcount_latency != nil { + *m.addcount_latency += i + } else { + m.addcount_latency = &i + } } -// Where appends a list predicates to the IdempotencyRecordMutation builder. -func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { +// AddedCountLatency returns the value that was added to the "count_latency" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedCountLatency() (r int, exists bool) { + v := m.addcount_latency + if v == nil { + return + } + return *v, true +} + +// ResetCountLatency resets all changes to the "count_latency" field. +func (m *ChannelMonitorDailyRollupMutation) ResetCountLatency() { + m.count_latency = nil + m.addcount_latency = nil +} + +// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field. 
+func (m *ChannelMonitorDailyRollupMutation) SetSumPingLatencyMs(i int64) { + m.sum_ping_latency_ms = &i + m.addsum_ping_latency_ms = nil +} + +// SumPingLatencyMs returns the value of the "sum_ping_latency_ms" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) SumPingLatencyMs() (r int64, exists bool) { + v := m.sum_ping_latency_ms + if v == nil { + return + } + return *v, true +} + +// OldSumPingLatencyMs returns the old "sum_ping_latency_ms" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorDailyRollupMutation) OldSumPingLatencyMs(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSumPingLatencyMs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSumPingLatencyMs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSumPingLatencyMs: %w", err) + } + return oldValue.SumPingLatencyMs, nil +} + +// AddSumPingLatencyMs adds i to the "sum_ping_latency_ms" field. +func (m *ChannelMonitorDailyRollupMutation) AddSumPingLatencyMs(i int64) { + if m.addsum_ping_latency_ms != nil { + *m.addsum_ping_latency_ms += i + } else { + m.addsum_ping_latency_ms = &i + } +} + +// AddedSumPingLatencyMs returns the value that was added to the "sum_ping_latency_ms" field in this mutation. +func (m *ChannelMonitorDailyRollupMutation) AddedSumPingLatencyMs() (r int64, exists bool) { + v := m.addsum_ping_latency_ms + if v == nil { + return + } + return *v, true +} + +// ResetSumPingLatencyMs resets all changes to the "sum_ping_latency_ms" field. 
+func (m *ChannelMonitorDailyRollupMutation) ResetSumPingLatencyMs() { + m.sum_ping_latency_ms = nil + m.addsum_ping_latency_ms = nil +} + +// SetCountPingLatency sets the "count_ping_latency" field. +func (m *ChannelMonitorDailyRollupMutation) SetCountPingLatency(i int) { + m.count_ping_latency = &i + m.addcount_ping_latency = nil +} + +// CountPingLatency returns the value of the "count_ping_latency" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) CountPingLatency() (r int, exists bool) { + v := m.count_ping_latency + if v == nil { + return + } + return *v, true +} + +// OldCountPingLatency returns the old "count_ping_latency" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorDailyRollupMutation) OldCountPingLatency(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCountPingLatency is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCountPingLatency requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCountPingLatency: %w", err) + } + return oldValue.CountPingLatency, nil +} + +// AddCountPingLatency adds i to the "count_ping_latency" field. +func (m *ChannelMonitorDailyRollupMutation) AddCountPingLatency(i int) { + if m.addcount_ping_latency != nil { + *m.addcount_ping_latency += i + } else { + m.addcount_ping_latency = &i + } +} + +// AddedCountPingLatency returns the value that was added to the "count_ping_latency" field in this mutation. 
+func (m *ChannelMonitorDailyRollupMutation) AddedCountPingLatency() (r int, exists bool) { + v := m.addcount_ping_latency + if v == nil { + return + } + return *v, true +} + +// ResetCountPingLatency resets all changes to the "count_ping_latency" field. +func (m *ChannelMonitorDailyRollupMutation) ResetCountPingLatency() { + m.count_ping_latency = nil + m.addcount_ping_latency = nil +} + +// SetComputedAt sets the "computed_at" field. +func (m *ChannelMonitorDailyRollupMutation) SetComputedAt(t time.Time) { + m.computed_at = &t +} + +// ComputedAt returns the value of the "computed_at" field in the mutation. +func (m *ChannelMonitorDailyRollupMutation) ComputedAt() (r time.Time, exists bool) { + v := m.computed_at + if v == nil { + return + } + return *v, true +} + +// OldComputedAt returns the old "computed_at" field's value of the ChannelMonitorDailyRollup entity. +// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ChannelMonitorDailyRollupMutation) OldComputedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldComputedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldComputedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldComputedAt: %w", err) + } + return oldValue.ComputedAt, nil +} + +// ResetComputedAt resets all changes to the "computed_at" field. +func (m *ChannelMonitorDailyRollupMutation) ResetComputedAt() { + m.computed_at = nil +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. 
+func (m *ChannelMonitorDailyRollupMutation) ClearMonitor() { + m.clearedmonitor = true + m.clearedFields[channelmonitordailyrollup.FieldMonitorID] = struct{}{} +} + +// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared. +func (m *ChannelMonitorDailyRollupMutation) MonitorCleared() bool { + return m.clearedmonitor +} + +// MonitorIDs returns the "monitor" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// MonitorID instead. It exists only for internal usage by the builders. +func (m *ChannelMonitorDailyRollupMutation) MonitorIDs() (ids []int64) { + if id := m.monitor; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetMonitor resets all changes to the "monitor" edge. +func (m *ChannelMonitorDailyRollupMutation) ResetMonitor() { + m.monitor = nil + m.clearedmonitor = false +} + +// Where appends a list predicates to the ChannelMonitorDailyRollupMutation builder. +func (m *ChannelMonitorDailyRollupMutation) Where(ps ...predicate.ChannelMonitorDailyRollup) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// WhereP appends storage-level predicates to the ChannelMonitorDailyRollupMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.IdempotencyRecord, len(ps)) +func (m *ChannelMonitorDailyRollupMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ChannelMonitorDailyRollup, len(ps)) for i := range ps { p[i] = ps[i] } @@ -11816,57 +11213,66 @@ func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. 
-func (m *IdempotencyRecordMutation) Op() Op { +func (m *ChannelMonitorDailyRollupMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *IdempotencyRecordMutation) SetOp(op Op) { +func (m *ChannelMonitorDailyRollupMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (IdempotencyRecord). -func (m *IdempotencyRecordMutation) Type() string { +// Type returns the node type of this mutation (ChannelMonitorDailyRollup). +func (m *ChannelMonitorDailyRollupMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *IdempotencyRecordMutation) Fields() []string { - fields := make([]string, 0, 11) - if m.created_at != nil { - fields = append(fields, idempotencyrecord.FieldCreatedAt) +func (m *ChannelMonitorDailyRollupMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.monitor != nil { + fields = append(fields, channelmonitordailyrollup.FieldMonitorID) } - if m.updated_at != nil { - fields = append(fields, idempotencyrecord.FieldUpdatedAt) + if m.model != nil { + fields = append(fields, channelmonitordailyrollup.FieldModel) } - if m.scope != nil { - fields = append(fields, idempotencyrecord.FieldScope) + if m.bucket_date != nil { + fields = append(fields, channelmonitordailyrollup.FieldBucketDate) } - if m.idempotency_key_hash != nil { - fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + if m.total_checks != nil { + fields = append(fields, channelmonitordailyrollup.FieldTotalChecks) } - if m.request_fingerprint != nil { - fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + if m.ok_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldOkCount) } - if m.status != nil { - fields = append(fields, idempotencyrecord.FieldStatus) + if m.operational_count != nil { + fields = 
append(fields, channelmonitordailyrollup.FieldOperationalCount) } - if m.response_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.degraded_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldDegradedCount) } - if m.response_body != nil { - fields = append(fields, idempotencyrecord.FieldResponseBody) + if m.failed_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldFailedCount) } - if m.error_reason != nil { - fields = append(fields, idempotencyrecord.FieldErrorReason) + if m.error_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldErrorCount) } - if m.locked_until != nil { - fields = append(fields, idempotencyrecord.FieldLockedUntil) + if m.sum_latency_ms != nil { + fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs) } - if m.expires_at != nil { - fields = append(fields, idempotencyrecord.FieldExpiresAt) + if m.count_latency != nil { + fields = append(fields, channelmonitordailyrollup.FieldCountLatency) + } + if m.sum_ping_latency_ms != nil { + fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs) + } + if m.count_ping_latency != nil { + fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency) + } + if m.computed_at != nil { + fields = append(fields, channelmonitordailyrollup.FieldComputedAt) } return fields } @@ -11874,30 +11280,36 @@ func (m *IdempotencyRecordMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. 
-func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { +func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool) { switch name { - case idempotencyrecord.FieldCreatedAt: - return m.CreatedAt() - case idempotencyrecord.FieldUpdatedAt: - return m.UpdatedAt() - case idempotencyrecord.FieldScope: - return m.Scope() - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.IdempotencyKeyHash() - case idempotencyrecord.FieldRequestFingerprint: - return m.RequestFingerprint() - case idempotencyrecord.FieldStatus: - return m.Status() - case idempotencyrecord.FieldResponseStatus: - return m.ResponseStatus() - case idempotencyrecord.FieldResponseBody: - return m.ResponseBody() - case idempotencyrecord.FieldErrorReason: - return m.ErrorReason() - case idempotencyrecord.FieldLockedUntil: - return m.LockedUntil() - case idempotencyrecord.FieldExpiresAt: - return m.ExpiresAt() + case channelmonitordailyrollup.FieldMonitorID: + return m.MonitorID() + case channelmonitordailyrollup.FieldModel: + return m.Model() + case channelmonitordailyrollup.FieldBucketDate: + return m.BucketDate() + case channelmonitordailyrollup.FieldTotalChecks: + return m.TotalChecks() + case channelmonitordailyrollup.FieldOkCount: + return m.OkCount() + case channelmonitordailyrollup.FieldOperationalCount: + return m.OperationalCount() + case channelmonitordailyrollup.FieldDegradedCount: + return m.DegradedCount() + case channelmonitordailyrollup.FieldFailedCount: + return m.FailedCount() + case channelmonitordailyrollup.FieldErrorCount: + return m.ErrorCount() + case channelmonitordailyrollup.FieldSumLatencyMs: + return m.SumLatencyMs() + case channelmonitordailyrollup.FieldCountLatency: + return m.CountLatency() + case channelmonitordailyrollup.FieldSumPingLatencyMs: + return m.SumPingLatencyMs() + case channelmonitordailyrollup.FieldCountPingLatency: + return m.CountPingLatency() + case channelmonitordailyrollup.FieldComputedAt: + return m.ComputedAt() } 
return nil, false } @@ -11905,126 +11317,180 @@ func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case idempotencyrecord.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case idempotencyrecord.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case idempotencyrecord.FieldScope: - return m.OldScope(ctx) - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.OldIdempotencyKeyHash(ctx) - case idempotencyrecord.FieldRequestFingerprint: - return m.OldRequestFingerprint(ctx) - case idempotencyrecord.FieldStatus: - return m.OldStatus(ctx) - case idempotencyrecord.FieldResponseStatus: - return m.OldResponseStatus(ctx) - case idempotencyrecord.FieldResponseBody: - return m.OldResponseBody(ctx) - case idempotencyrecord.FieldErrorReason: - return m.OldErrorReason(ctx) - case idempotencyrecord.FieldLockedUntil: - return m.OldLockedUntil(ctx) - case idempotencyrecord.FieldExpiresAt: - return m.OldExpiresAt(ctx) - } - return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) + case channelmonitordailyrollup.FieldMonitorID: + return m.OldMonitorID(ctx) + case channelmonitordailyrollup.FieldModel: + return m.OldModel(ctx) + case channelmonitordailyrollup.FieldBucketDate: + return m.OldBucketDate(ctx) + case channelmonitordailyrollup.FieldTotalChecks: + return m.OldTotalChecks(ctx) + case channelmonitordailyrollup.FieldOkCount: + return m.OldOkCount(ctx) + case channelmonitordailyrollup.FieldOperationalCount: + return m.OldOperationalCount(ctx) + case channelmonitordailyrollup.FieldDegradedCount: + return m.OldDegradedCount(ctx) + case 
channelmonitordailyrollup.FieldFailedCount: + return m.OldFailedCount(ctx) + case channelmonitordailyrollup.FieldErrorCount: + return m.OldErrorCount(ctx) + case channelmonitordailyrollup.FieldSumLatencyMs: + return m.OldSumLatencyMs(ctx) + case channelmonitordailyrollup.FieldCountLatency: + return m.OldCountLatency(ctx) + case channelmonitordailyrollup.FieldSumPingLatencyMs: + return m.OldSumPingLatencyMs(ctx) + case channelmonitordailyrollup.FieldCountPingLatency: + return m.OldCountPingLatency(ctx) + case channelmonitordailyrollup.FieldComputedAt: + return m.OldComputedAt(ctx) + } + return nil, fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { +func (m *ChannelMonitorDailyRollupMutation) SetField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldCreatedAt: - v, ok := value.(time.Time) + case channelmonitordailyrollup.FieldMonitorID: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetMonitorID(v) return nil - case idempotencyrecord.FieldUpdatedAt: + case channelmonitordailyrollup.FieldModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModel(v) + return nil + case channelmonitordailyrollup.FieldBucketDate: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetBucketDate(v) return nil - case idempotencyrecord.FieldScope: - v, ok := value.(string) + case channelmonitordailyrollup.FieldTotalChecks: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetScope(v) + m.SetTotalChecks(v) return nil - 
case idempotencyrecord.FieldIdempotencyKeyHash: - v, ok := value.(string) + case channelmonitordailyrollup.FieldOkCount: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetIdempotencyKeyHash(v) + m.SetOkCount(v) return nil - case idempotencyrecord.FieldRequestFingerprint: - v, ok := value.(string) + case channelmonitordailyrollup.FieldOperationalCount: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRequestFingerprint(v) + m.SetOperationalCount(v) return nil - case idempotencyrecord.FieldStatus: - v, ok := value.(string) + case channelmonitordailyrollup.FieldDegradedCount: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetDegradedCount(v) return nil - case idempotencyrecord.FieldResponseStatus: + case channelmonitordailyrollup.FieldFailedCount: v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseStatus(v) + m.SetFailedCount(v) return nil - case idempotencyrecord.FieldResponseBody: - v, ok := value.(string) + case channelmonitordailyrollup.FieldErrorCount: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseBody(v) + m.SetErrorCount(v) return nil - case idempotencyrecord.FieldErrorReason: - v, ok := value.(string) + case channelmonitordailyrollup.FieldSumLatencyMs: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetErrorReason(v) + m.SetSumLatencyMs(v) return nil - case idempotencyrecord.FieldLockedUntil: - v, ok := value.(time.Time) + case channelmonitordailyrollup.FieldCountLatency: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLockedUntil(v) + m.SetCountLatency(v) return nil - case idempotencyrecord.FieldExpiresAt: + case 
channelmonitordailyrollup.FieldSumPingLatencyMs: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSumPingLatencyMs(v) + return nil + case channelmonitordailyrollup.FieldCountPingLatency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCountPingLatency(v) + return nil + case channelmonitordailyrollup.FieldComputedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetExpiresAt(v) + m.SetComputedAt(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *IdempotencyRecordMutation) AddedFields() []string { +func (m *ChannelMonitorDailyRollupMutation) AddedFields() []string { var fields []string - if m.addresponse_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.addtotal_checks != nil { + fields = append(fields, channelmonitordailyrollup.FieldTotalChecks) + } + if m.addok_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldOkCount) + } + if m.addoperational_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldOperationalCount) + } + if m.adddegraded_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldDegradedCount) + } + if m.addfailed_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldFailedCount) + } + if m.adderror_count != nil { + fields = append(fields, channelmonitordailyrollup.FieldErrorCount) + } + if m.addsum_latency_ms != nil { + fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs) + } + if m.addcount_latency != nil { + fields = append(fields, channelmonitordailyrollup.FieldCountLatency) + } + if m.addsum_ping_latency_ms != nil { + 
fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs) + } + if m.addcount_ping_latency != nil { + fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency) } return fields } @@ -12032,10 +11498,28 @@ func (m *IdempotencyRecordMutation) AddedFields() []string { // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { +func (m *ChannelMonitorDailyRollupMutation) AddedField(name string) (ent.Value, bool) { switch name { - case idempotencyrecord.FieldResponseStatus: - return m.AddedResponseStatus() + case channelmonitordailyrollup.FieldTotalChecks: + return m.AddedTotalChecks() + case channelmonitordailyrollup.FieldOkCount: + return m.AddedOkCount() + case channelmonitordailyrollup.FieldOperationalCount: + return m.AddedOperationalCount() + case channelmonitordailyrollup.FieldDegradedCount: + return m.AddedDegradedCount() + case channelmonitordailyrollup.FieldFailedCount: + return m.AddedFailedCount() + case channelmonitordailyrollup.FieldErrorCount: + return m.AddedErrorCount() + case channelmonitordailyrollup.FieldSumLatencyMs: + return m.AddedSumLatencyMs() + case channelmonitordailyrollup.FieldCountLatency: + return m.AddedCountLatency() + case channelmonitordailyrollup.FieldSumPingLatencyMs: + return m.AddedSumPingLatencyMs() + case channelmonitordailyrollup.FieldCountPingLatency: + return m.AddedCountPingLatency() } return nil, false } @@ -12043,182 +11527,258 @@ func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { +func (m *ChannelMonitorDailyRollupMutation) AddField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldResponseStatus: + case channelmonitordailyrollup.FieldTotalChecks: v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddResponseStatus(v) + m.AddTotalChecks(v) + return nil + case channelmonitordailyrollup.FieldOkCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOkCount(v) + return nil + case channelmonitordailyrollup.FieldOperationalCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOperationalCount(v) + return nil + case channelmonitordailyrollup.FieldDegradedCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDegradedCount(v) + return nil + case channelmonitordailyrollup.FieldFailedCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFailedCount(v) + return nil + case channelmonitordailyrollup.FieldErrorCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddErrorCount(v) + return nil + case channelmonitordailyrollup.FieldSumLatencyMs: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSumLatencyMs(v) + return nil + case channelmonitordailyrollup.FieldCountLatency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCountLatency(v) + return nil + case channelmonitordailyrollup.FieldSumPingLatencyMs: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSumPingLatencyMs(v) + return 
nil + case channelmonitordailyrollup.FieldCountPingLatency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCountPingLatency(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) + return fmt.Errorf("unknown ChannelMonitorDailyRollup numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *IdempotencyRecordMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { - fields = append(fields, idempotencyrecord.FieldResponseStatus) - } - if m.FieldCleared(idempotencyrecord.FieldResponseBody) { - fields = append(fields, idempotencyrecord.FieldResponseBody) - } - if m.FieldCleared(idempotencyrecord.FieldErrorReason) { - fields = append(fields, idempotencyrecord.FieldErrorReason) - } - if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { - fields = append(fields, idempotencyrecord.FieldLockedUntil) - } - return fields +func (m *ChannelMonitorDailyRollupMutation) ClearedFields() []string { + return nil } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { +func (m *ChannelMonitorDailyRollupMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. 
-func (m *IdempotencyRecordMutation) ClearField(name string) error { - switch name { - case idempotencyrecord.FieldResponseStatus: - m.ClearResponseStatus() - return nil - case idempotencyrecord.FieldResponseBody: - m.ClearResponseBody() - return nil - case idempotencyrecord.FieldErrorReason: - m.ClearErrorReason() - return nil - case idempotencyrecord.FieldLockedUntil: - m.ClearLockedUntil() - return nil - } - return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error { + return fmt.Errorf("unknown ChannelMonitorDailyRollup nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *IdempotencyRecordMutation) ResetField(name string) error { +func (m *ChannelMonitorDailyRollupMutation) ResetField(name string) error { switch name { - case idempotencyrecord.FieldCreatedAt: - m.ResetCreatedAt() + case channelmonitordailyrollup.FieldMonitorID: + m.ResetMonitorID() return nil - case idempotencyrecord.FieldUpdatedAt: - m.ResetUpdatedAt() + case channelmonitordailyrollup.FieldModel: + m.ResetModel() return nil - case idempotencyrecord.FieldScope: - m.ResetScope() + case channelmonitordailyrollup.FieldBucketDate: + m.ResetBucketDate() return nil - case idempotencyrecord.FieldIdempotencyKeyHash: - m.ResetIdempotencyKeyHash() + case channelmonitordailyrollup.FieldTotalChecks: + m.ResetTotalChecks() return nil - case idempotencyrecord.FieldRequestFingerprint: - m.ResetRequestFingerprint() + case channelmonitordailyrollup.FieldOkCount: + m.ResetOkCount() return nil - case idempotencyrecord.FieldStatus: - m.ResetStatus() + case channelmonitordailyrollup.FieldOperationalCount: + m.ResetOperationalCount() return nil - case idempotencyrecord.FieldResponseStatus: - m.ResetResponseStatus() + case channelmonitordailyrollup.FieldDegradedCount: + m.ResetDegradedCount() 
return nil - case idempotencyrecord.FieldResponseBody: - m.ResetResponseBody() + case channelmonitordailyrollup.FieldFailedCount: + m.ResetFailedCount() return nil - case idempotencyrecord.FieldErrorReason: - m.ResetErrorReason() + case channelmonitordailyrollup.FieldErrorCount: + m.ResetErrorCount() return nil - case idempotencyrecord.FieldLockedUntil: - m.ResetLockedUntil() + case channelmonitordailyrollup.FieldSumLatencyMs: + m.ResetSumLatencyMs() return nil - case idempotencyrecord.FieldExpiresAt: - m.ResetExpiresAt() + case channelmonitordailyrollup.FieldCountLatency: + m.ResetCountLatency() + return nil + case channelmonitordailyrollup.FieldSumPingLatencyMs: + m.ResetSumPingLatencyMs() + return nil + case channelmonitordailyrollup.FieldCountPingLatency: + m.ResetCountPingLatency() + return nil + case channelmonitordailyrollup.FieldComputedAt: + m.ResetComputedAt() return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *IdempotencyRecordMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorDailyRollupMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.monitor != nil { + edges = append(edges, channelmonitordailyrollup.EdgeMonitor) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { +func (m *ChannelMonitorDailyRollupMutation) AddedIDs(name string) []ent.Value { + switch name { + case channelmonitordailyrollup.EdgeMonitor: + if id := m.monitor; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
-func (m *IdempotencyRecordMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorDailyRollupMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { +func (m *ChannelMonitorDailyRollupMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *IdempotencyRecordMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorDailyRollupMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedmonitor { + edges = append(edges, channelmonitordailyrollup.EdgeMonitor) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { +func (m *ChannelMonitorDailyRollupMutation) EdgeCleared(name string) bool { + switch name { + case channelmonitordailyrollup.EdgeMonitor: + return m.clearedmonitor + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *IdempotencyRecordMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +func (m *ChannelMonitorDailyRollupMutation) ClearEdge(name string) error { + switch name { + case channelmonitordailyrollup.EdgeMonitor: + m.ClearMonitor() + return nil + } + return fmt.Errorf("unknown ChannelMonitorDailyRollup unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. 
-func (m *IdempotencyRecordMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +func (m *ChannelMonitorDailyRollupMutation) ResetEdge(name string) error { + switch name { + case channelmonitordailyrollup.EdgeMonitor: + m.ResetMonitor() + return nil + } + return fmt.Errorf("unknown ChannelMonitorDailyRollup edge %s", name) } -// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. -type PaymentAuditLogMutation struct { +// ChannelMonitorHistoryMutation represents an operation that mutates the ChannelMonitorHistory nodes in the graph. +type ChannelMonitorHistoryMutation struct { config - op Op - typ string - id *int64 - order_id *string - action *string - detail *string - operator *string - created_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentAuditLog, error) - predicates []predicate.PaymentAuditLog + op Op + typ string + id *int64 + model *string + status *channelmonitorhistory.Status + latency_ms *int + addlatency_ms *int + ping_latency_ms *int + addping_latency_ms *int + message *string + checked_at *time.Time + clearedFields map[string]struct{} + monitor *int64 + clearedmonitor bool + done bool + oldValue func(context.Context) (*ChannelMonitorHistory, error) + predicates []predicate.ChannelMonitorHistory } -var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) +var _ ent.Mutation = (*ChannelMonitorHistoryMutation)(nil) -// paymentauditlogOption allows management of the mutation configuration using functional options. -type paymentauditlogOption func(*PaymentAuditLogMutation) +// channelmonitorhistoryOption allows management of the mutation configuration using functional options. +type channelmonitorhistoryOption func(*ChannelMonitorHistoryMutation) -// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. 
-func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { - m := &PaymentAuditLogMutation{ +// newChannelMonitorHistoryMutation creates new mutation for the ChannelMonitorHistory entity. +func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhistoryOption) *ChannelMonitorHistoryMutation { + m := &ChannelMonitorHistoryMutation{ config: c, op: op, - typ: TypePaymentAuditLog, + typ: TypeChannelMonitorHistory, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -12227,20 +11787,20 @@ func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) return m } -// withPaymentAuditLogID sets the ID field of the mutation. -func withPaymentAuditLogID(id int64) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { +// withChannelMonitorHistoryID sets the ID field of the mutation. +func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption { + return func(m *ChannelMonitorHistoryMutation) { var ( err error once sync.Once - value *PaymentAuditLog + value *ChannelMonitorHistory ) - m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + m.oldValue = func(ctx context.Context) (*ChannelMonitorHistory, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentAuditLog.Get(ctx, id) + value, err = m.Client().ChannelMonitorHistory.Get(ctx, id) } }) return value, err @@ -12249,10 +11809,10 @@ func withPaymentAuditLogID(id int64) paymentauditlogOption { } } -// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. -func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { - m.oldValue = func(context.Context) (*PaymentAuditLog, error) { +// withChannelMonitorHistory sets the old ChannelMonitorHistory of the mutation. 
+func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistoryOption { + return func(m *ChannelMonitorHistoryMutation) { + m.oldValue = func(context.Context) (*ChannelMonitorHistory, error) { return node, nil } m.id = &node.ID @@ -12261,7 +11821,7 @@ func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentAuditLogMutation) Client() *Client { +func (m ChannelMonitorHistoryMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -12269,7 +11829,7 @@ func (m PaymentAuditLogMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentAuditLogMutation) Tx() (*Tx, error) { +func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -12280,7 +11840,7 @@ func (m PaymentAuditLogMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { +func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -12291,7 +11851,7 @@ func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. 
-func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -12300,201 +11860,381 @@ func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + return m.Client().ChannelMonitorHistory.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetOrderID sets the "order_id" field. -func (m *PaymentAuditLogMutation) SetOrderID(s string) { - m.order_id = &s +// SetMonitorID sets the "monitor_id" field. +func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) { + m.monitor = &i } -// OrderID returns the value of the "order_id" field in the mutation. -func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { - v := m.order_id +// MonitorID returns the value of the "monitor_id" field in the mutation. +func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) { + v := m.monitor if v == nil { return } return *v, true } -// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderID is only allowed on UpdateOne operations") + return v, errors.New("OldMonitorID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderID requires an ID field in the mutation") + return v, errors.New("OldMonitorID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOrderID: %w", err) + return v, fmt.Errorf("querying old value for OldMonitorID: %w", err) } - return oldValue.OrderID, nil + return oldValue.MonitorID, nil } -// ResetOrderID resets all changes to the "order_id" field. -func (m *PaymentAuditLogMutation) ResetOrderID() { - m.order_id = nil +// ResetMonitorID resets all changes to the "monitor_id" field. +func (m *ChannelMonitorHistoryMutation) ResetMonitorID() { + m.monitor = nil } -// SetAction sets the "action" field. -func (m *PaymentAuditLogMutation) SetAction(s string) { - m.action = &s +// SetModel sets the "model" field. +func (m *ChannelMonitorHistoryMutation) SetModel(s string) { + m.model = &s } -// Action returns the value of the "action" field in the mutation. -func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { - v := m.action +// Model returns the value of the "model" field in the mutation. +func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) { + v := m.model if v == nil { return } return *v, true } -// OldAction returns the old "action" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldModel returns the old "model" field's value of the ChannelMonitorHistory entity. 
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAction is only allowed on UpdateOne operations") + return v, errors.New("OldModel is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAction requires an ID field in the mutation") + return v, errors.New("OldModel requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAction: %w", err) + return v, fmt.Errorf("querying old value for OldModel: %w", err) } - return oldValue.Action, nil + return oldValue.Model, nil } -// ResetAction resets all changes to the "action" field. -func (m *PaymentAuditLogMutation) ResetAction() { - m.action = nil +// ResetModel resets all changes to the "model" field. +func (m *ChannelMonitorHistoryMutation) ResetModel() { + m.model = nil } -// SetDetail sets the "detail" field. -func (m *PaymentAuditLogMutation) SetDetail(s string) { - m.detail = &s +// SetStatus sets the "status" field. +func (m *ChannelMonitorHistoryMutation) SetStatus(c channelmonitorhistory.Status) { + m.status = &c } -// Detail returns the value of the "detail" field in the mutation. -func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { - v := m.detail +// Status returns the value of the "status" field in the mutation. +func (m *ChannelMonitorHistoryMutation) Status() (r channelmonitorhistory.Status, exists bool) { + v := m.status if v == nil { return } return *v, true } -// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. 
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldStatus returns the old "status" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorHistoryMutation) OldStatus(ctx context.Context) (v channelmonitorhistory.Status, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDetail is only allowed on UpdateOne operations") + return v, errors.New("OldStatus is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDetail requires an ID field in the mutation") + return v, errors.New("OldStatus requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDetail: %w", err) + return v, fmt.Errorf("querying old value for OldStatus: %w", err) } - return oldValue.Detail, nil + return oldValue.Status, nil } -// ResetDetail resets all changes to the "detail" field. -func (m *PaymentAuditLogMutation) ResetDetail() { - m.detail = nil +// ResetStatus resets all changes to the "status" field. +func (m *ChannelMonitorHistoryMutation) ResetStatus() { + m.status = nil } -// SetOperator sets the "operator" field. -func (m *PaymentAuditLogMutation) SetOperator(s string) { - m.operator = &s +// SetLatencyMs sets the "latency_ms" field. +func (m *ChannelMonitorHistoryMutation) SetLatencyMs(i int) { + m.latency_ms = &i + m.addlatency_ms = nil } -// Operator returns the value of the "operator" field in the mutation. 
-func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { - v := m.operator +// LatencyMs returns the value of the "latency_ms" field in the mutation. +func (m *ChannelMonitorHistoryMutation) LatencyMs() (r int, exists bool) { + v := m.latency_ms if v == nil { return } return *v, true } -// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldLatencyMs returns the old "latency_ms" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorHistoryMutation) OldLatencyMs(ctx context.Context) (v *int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOperator is only allowed on UpdateOne operations") + return v, errors.New("OldLatencyMs is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOperator requires an ID field in the mutation") + return v, errors.New("OldLatencyMs requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOperator: %w", err) + return v, fmt.Errorf("querying old value for OldLatencyMs: %w", err) } - return oldValue.Operator, nil + return oldValue.LatencyMs, nil } -// ResetOperator resets all changes to the "operator" field. -func (m *PaymentAuditLogMutation) ResetOperator() { - m.operator = nil +// AddLatencyMs adds i to the "latency_ms" field. 
+func (m *ChannelMonitorHistoryMutation) AddLatencyMs(i int) { + if m.addlatency_ms != nil { + *m.addlatency_ms += i + } else { + m.addlatency_ms = &i + } } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// AddedLatencyMs returns the value that was added to the "latency_ms" field in this mutation. +func (m *ChannelMonitorHistoryMutation) AddedLatencyMs() (r int, exists bool) { + v := m.addlatency_ms + if v == nil { + return + } + return *v, true } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// ClearLatencyMs clears the value of the "latency_ms" field. +func (m *ChannelMonitorHistoryMutation) ClearLatencyMs() { + m.latency_ms = nil + m.addlatency_ms = nil + m.clearedFields[channelmonitorhistory.FieldLatencyMs] = struct{}{} +} + +// LatencyMsCleared returns if the "latency_ms" field was cleared in this mutation. +func (m *ChannelMonitorHistoryMutation) LatencyMsCleared() bool { + _, ok := m.clearedFields[channelmonitorhistory.FieldLatencyMs] + return ok +} + +// ResetLatencyMs resets all changes to the "latency_ms" field. +func (m *ChannelMonitorHistoryMutation) ResetLatencyMs() { + m.latency_ms = nil + m.addlatency_ms = nil + delete(m.clearedFields, channelmonitorhistory.FieldLatencyMs) +} + +// SetPingLatencyMs sets the "ping_latency_ms" field. +func (m *ChannelMonitorHistoryMutation) SetPingLatencyMs(i int) { + m.ping_latency_ms = &i + m.addping_latency_ms = nil +} + +// PingLatencyMs returns the value of the "ping_latency_ms" field in the mutation. +func (m *ChannelMonitorHistoryMutation) PingLatencyMs() (r int, exists bool) { + v := m.ping_latency_ms if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. 
-// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldPingLatencyMs returns the old "ping_latency_ms" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ChannelMonitorHistoryMutation) OldPingLatencyMs(ctx context.Context) (v *int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldPingLatencyMs is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldPingLatencyMs requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldPingLatencyMs: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.PingLatencyMs, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentAuditLogMutation) ResetCreatedAt() { - m.created_at = nil +// AddPingLatencyMs adds i to the "ping_latency_ms" field. +func (m *ChannelMonitorHistoryMutation) AddPingLatencyMs(i int) { + if m.addping_latency_ms != nil { + *m.addping_latency_ms += i + } else { + m.addping_latency_ms = &i + } } -// Where appends a list predicates to the PaymentAuditLogMutation builder. -func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { +// AddedPingLatencyMs returns the value that was added to the "ping_latency_ms" field in this mutation. 
+func (m *ChannelMonitorHistoryMutation) AddedPingLatencyMs() (r int, exists bool) { + v := m.addping_latency_ms + if v == nil { + return + } + return *v, true +} + +// ClearPingLatencyMs clears the value of the "ping_latency_ms" field. +func (m *ChannelMonitorHistoryMutation) ClearPingLatencyMs() { + m.ping_latency_ms = nil + m.addping_latency_ms = nil + m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] = struct{}{} +} + +// PingLatencyMsCleared returns if the "ping_latency_ms" field was cleared in this mutation. +func (m *ChannelMonitorHistoryMutation) PingLatencyMsCleared() bool { + _, ok := m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] + return ok +} + +// ResetPingLatencyMs resets all changes to the "ping_latency_ms" field. +func (m *ChannelMonitorHistoryMutation) ResetPingLatencyMs() { + m.ping_latency_ms = nil + m.addping_latency_ms = nil + delete(m.clearedFields, channelmonitorhistory.FieldPingLatencyMs) +} + +// SetMessage sets the "message" field. +func (m *ChannelMonitorHistoryMutation) SetMessage(s string) { + m.message = &s +} + +// Message returns the value of the "message" field in the mutation. +func (m *ChannelMonitorHistoryMutation) Message() (r string, exists bool) { + v := m.message + if v == nil { + return + } + return *v, true +} + +// OldMessage returns the old "message" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *ChannelMonitorHistoryMutation) OldMessage(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMessage: %w", err) + } + return oldValue.Message, nil +} + +// ClearMessage clears the value of the "message" field. +func (m *ChannelMonitorHistoryMutation) ClearMessage() { + m.message = nil + m.clearedFields[channelmonitorhistory.FieldMessage] = struct{}{} +} + +// MessageCleared returns if the "message" field was cleared in this mutation. +func (m *ChannelMonitorHistoryMutation) MessageCleared() bool { + _, ok := m.clearedFields[channelmonitorhistory.FieldMessage] + return ok +} + +// ResetMessage resets all changes to the "message" field. +func (m *ChannelMonitorHistoryMutation) ResetMessage() { + m.message = nil + delete(m.clearedFields, channelmonitorhistory.FieldMessage) +} + +// SetCheckedAt sets the "checked_at" field. +func (m *ChannelMonitorHistoryMutation) SetCheckedAt(t time.Time) { + m.checked_at = &t +} + +// CheckedAt returns the value of the "checked_at" field in the mutation. +func (m *ChannelMonitorHistoryMutation) CheckedAt() (r time.Time, exists bool) { + v := m.checked_at + if v == nil { + return + } + return *v, true +} + +// OldCheckedAt returns the old "checked_at" field's value of the ChannelMonitorHistory entity. +// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *ChannelMonitorHistoryMutation) OldCheckedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCheckedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCheckedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCheckedAt: %w", err) + } + return oldValue.CheckedAt, nil +} + +// ResetCheckedAt resets all changes to the "checked_at" field. +func (m *ChannelMonitorHistoryMutation) ResetCheckedAt() { + m.checked_at = nil +} + +// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity. +func (m *ChannelMonitorHistoryMutation) ClearMonitor() { + m.clearedmonitor = true + m.clearedFields[channelmonitorhistory.FieldMonitorID] = struct{}{} +} + +// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared. +func (m *ChannelMonitorHistoryMutation) MonitorCleared() bool { + return m.clearedmonitor +} + +// MonitorIDs returns the "monitor" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// MonitorID instead. It exists only for internal usage by the builders. +func (m *ChannelMonitorHistoryMutation) MonitorIDs() (ids []int64) { + if id := m.monitor; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetMonitor resets all changes to the "monitor" edge. +func (m *ChannelMonitorHistoryMutation) ResetMonitor() { + m.monitor = nil + m.clearedmonitor = false +} + +// Where appends a list predicates to the ChannelMonitorHistoryMutation builder. +func (m *ChannelMonitorHistoryMutation) Where(ps ...predicate.ChannelMonitorHistory) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. 
Using this method, +// WhereP appends storage-level predicates to the ChannelMonitorHistoryMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentAuditLog, len(ps)) +func (m *ChannelMonitorHistoryMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ChannelMonitorHistory, len(ps)) for i := range ps { p[i] = ps[i] } @@ -12502,39 +12242,45 @@ func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *PaymentAuditLogMutation) Op() Op { +func (m *ChannelMonitorHistoryMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *PaymentAuditLogMutation) SetOp(op Op) { +func (m *ChannelMonitorHistoryMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (PaymentAuditLog). -func (m *PaymentAuditLogMutation) Type() string { +// Type returns the node type of this mutation (ChannelMonitorHistory). +func (m *ChannelMonitorHistoryMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
-func (m *PaymentAuditLogMutation) Fields() []string { - fields := make([]string, 0, 5) - if m.order_id != nil { - fields = append(fields, paymentauditlog.FieldOrderID) +func (m *ChannelMonitorHistoryMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.monitor != nil { + fields = append(fields, channelmonitorhistory.FieldMonitorID) } - if m.action != nil { - fields = append(fields, paymentauditlog.FieldAction) + if m.model != nil { + fields = append(fields, channelmonitorhistory.FieldModel) } - if m.detail != nil { - fields = append(fields, paymentauditlog.FieldDetail) + if m.status != nil { + fields = append(fields, channelmonitorhistory.FieldStatus) } - if m.operator != nil { - fields = append(fields, paymentauditlog.FieldOperator) + if m.latency_ms != nil { + fields = append(fields, channelmonitorhistory.FieldLatencyMs) } - if m.created_at != nil { - fields = append(fields, paymentauditlog.FieldCreatedAt) + if m.ping_latency_ms != nil { + fields = append(fields, channelmonitorhistory.FieldPingLatencyMs) + } + if m.message != nil { + fields = append(fields, channelmonitorhistory.FieldMessage) + } + if m.checked_at != nil { + fields = append(fields, channelmonitorhistory.FieldCheckedAt) } return fields } @@ -12542,18 +12288,22 @@ func (m *PaymentAuditLogMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. 
-func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { +func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) { switch name { - case paymentauditlog.FieldOrderID: - return m.OrderID() - case paymentauditlog.FieldAction: - return m.Action() - case paymentauditlog.FieldDetail: - return m.Detail() - case paymentauditlog.FieldOperator: - return m.Operator() - case paymentauditlog.FieldCreatedAt: - return m.CreatedAt() + case channelmonitorhistory.FieldMonitorID: + return m.MonitorID() + case channelmonitorhistory.FieldModel: + return m.Model() + case channelmonitorhistory.FieldStatus: + return m.Status() + case channelmonitorhistory.FieldLatencyMs: + return m.LatencyMs() + case channelmonitorhistory.FieldPingLatencyMs: + return m.PingLatencyMs() + case channelmonitorhistory.FieldMessage: + return m.Message() + case channelmonitorhistory.FieldCheckedAt: + return m.CheckedAt() } return nil, false } @@ -12561,246 +12311,310 @@ func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. 
-func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentauditlog.FieldOrderID: - return m.OldOrderID(ctx) - case paymentauditlog.FieldAction: - return m.OldAction(ctx) - case paymentauditlog.FieldDetail: - return m.OldDetail(ctx) - case paymentauditlog.FieldOperator: - return m.OldOperator(ctx) - case paymentauditlog.FieldCreatedAt: - return m.OldCreatedAt(ctx) + case channelmonitorhistory.FieldMonitorID: + return m.OldMonitorID(ctx) + case channelmonitorhistory.FieldModel: + return m.OldModel(ctx) + case channelmonitorhistory.FieldStatus: + return m.OldStatus(ctx) + case channelmonitorhistory.FieldLatencyMs: + return m.OldLatencyMs(ctx) + case channelmonitorhistory.FieldPingLatencyMs: + return m.OldPingLatencyMs(ctx) + case channelmonitorhistory.FieldMessage: + return m.OldMessage(ctx) + case channelmonitorhistory.FieldCheckedAt: + return m.OldCheckedAt(ctx) } - return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) + return nil, fmt.Errorf("unknown ChannelMonitorHistory field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { +func (m *ChannelMonitorHistoryMutation) SetField(name string, value ent.Value) error { switch name { - case paymentauditlog.FieldOrderID: - v, ok := value.(string) + case channelmonitorhistory.FieldMonitorID: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetOrderID(v) + m.SetMonitorID(v) return nil - case paymentauditlog.FieldAction: + case channelmonitorhistory.FieldModel: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAction(v) + m.SetModel(v) return nil - case paymentauditlog.FieldDetail: - v, ok := value.(string) + case channelmonitorhistory.FieldStatus: + v, ok := value.(channelmonitorhistory.Status) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetDetail(v) + m.SetStatus(v) return nil - case paymentauditlog.FieldOperator: + case channelmonitorhistory.FieldLatencyMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLatencyMs(v) + return nil + case channelmonitorhistory.FieldPingLatencyMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPingLatencyMs(v) + return nil + case channelmonitorhistory.FieldMessage: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetOperator(v) + m.SetMessage(v) return nil - case paymentauditlog.FieldCreatedAt: + case channelmonitorhistory.FieldCheckedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetCheckedAt(v) return nil } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) + return fmt.Errorf("unknown ChannelMonitorHistory field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented 
during // this mutation. -func (m *PaymentAuditLogMutation) AddedFields() []string { - return nil +func (m *ChannelMonitorHistoryMutation) AddedFields() []string { + var fields []string + if m.addlatency_ms != nil { + fields = append(fields, channelmonitorhistory.FieldLatencyMs) + } + if m.addping_latency_ms != nil { + fields = append(fields, channelmonitorhistory.FieldPingLatencyMs) + } + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { +func (m *ChannelMonitorHistoryMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case channelmonitorhistory.FieldLatencyMs: + return m.AddedLatencyMs() + case channelmonitorhistory.FieldPingLatencyMs: + return m.AddedPingLatencyMs() + } return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { +func (m *ChannelMonitorHistoryMutation) AddField(name string, value ent.Value) error { switch name { + case channelmonitorhistory.FieldLatencyMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddLatencyMs(v) + return nil + case channelmonitorhistory.FieldPingLatencyMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPingLatencyMs(v) + return nil } - return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) + return fmt.Errorf("unknown ChannelMonitorHistory numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. 
-func (m *PaymentAuditLogMutation) ClearedFields() []string { - return nil +func (m *ChannelMonitorHistoryMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(channelmonitorhistory.FieldLatencyMs) { + fields = append(fields, channelmonitorhistory.FieldLatencyMs) + } + if m.FieldCleared(channelmonitorhistory.FieldPingLatencyMs) { + fields = append(fields, channelmonitorhistory.FieldPingLatencyMs) + } + if m.FieldCleared(channelmonitorhistory.FieldMessage) { + fields = append(fields, channelmonitorhistory.FieldMessage) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { +func (m *ChannelMonitorHistoryMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) +func (m *ChannelMonitorHistoryMutation) ClearField(name string) error { + switch name { + case channelmonitorhistory.FieldLatencyMs: + m.ClearLatencyMs() + return nil + case channelmonitorhistory.FieldPingLatencyMs: + m.ClearPingLatencyMs() + return nil + case channelmonitorhistory.FieldMessage: + m.ClearMessage() + return nil + } + return fmt.Errorf("unknown ChannelMonitorHistory nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. 
-func (m *PaymentAuditLogMutation) ResetField(name string) error { +func (m *ChannelMonitorHistoryMutation) ResetField(name string) error { switch name { - case paymentauditlog.FieldOrderID: - m.ResetOrderID() + case channelmonitorhistory.FieldMonitorID: + m.ResetMonitorID() return nil - case paymentauditlog.FieldAction: - m.ResetAction() + case channelmonitorhistory.FieldModel: + m.ResetModel() return nil - case paymentauditlog.FieldDetail: - m.ResetDetail() + case channelmonitorhistory.FieldStatus: + m.ResetStatus() return nil - case paymentauditlog.FieldOperator: - m.ResetOperator() + case channelmonitorhistory.FieldLatencyMs: + m.ResetLatencyMs() return nil - case paymentauditlog.FieldCreatedAt: - m.ResetCreatedAt() + case channelmonitorhistory.FieldPingLatencyMs: + m.ResetPingLatencyMs() + return nil + case channelmonitorhistory.FieldMessage: + m.ResetMessage() + return nil + case channelmonitorhistory.FieldCheckedAt: + m.ResetCheckedAt() return nil } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) + return fmt.Errorf("unknown ChannelMonitorHistory field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentAuditLogMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorHistoryMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.monitor != nil { + edges = append(edges, channelmonitorhistory.EdgeMonitor) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { +func (m *ChannelMonitorHistoryMutation) AddedIDs(name string) []ent.Value { + switch name { + case channelmonitorhistory.EdgeMonitor: + if id := m.monitor; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
-func (m *PaymentAuditLogMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorHistoryMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { +func (m *ChannelMonitorHistoryMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentAuditLogMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *ChannelMonitorHistoryMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedmonitor { + edges = append(edges, channelmonitorhistory.EdgeMonitor) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { +func (m *ChannelMonitorHistoryMutation) EdgeCleared(name string) bool { + switch name { + case channelmonitorhistory.EdgeMonitor: + return m.clearedmonitor + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) +func (m *ChannelMonitorHistoryMutation) ClearEdge(name string) error { + switch name { + case channelmonitorhistory.EdgeMonitor: + m.ClearMonitor() + return nil + } + return fmt.Errorf("unknown ChannelMonitorHistory unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. 
-func (m *PaymentAuditLogMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog edge %s", name) +func (m *ChannelMonitorHistoryMutation) ResetEdge(name string) error { + switch name { + case channelmonitorhistory.EdgeMonitor: + m.ResetMonitor() + return nil + } + return fmt.Errorf("unknown ChannelMonitorHistory edge %s", name) } -// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. -type PaymentOrderMutation struct { +// ChannelMonitorRequestTemplateMutation represents an operation that mutates the ChannelMonitorRequestTemplate nodes in the graph. +type ChannelMonitorRequestTemplateMutation struct { config - op Op - typ string - id *int64 - user_email *string - user_name *string - user_notes *string - amount *float64 - addamount *float64 - pay_amount *float64 - addpay_amount *float64 - fee_rate *float64 - addfee_rate *float64 - recharge_code *string - out_trade_no *string - payment_type *string - payment_trade_no *string - pay_url *string - qr_code *string - qr_code_img *string - order_type *string - plan_id *int64 - addplan_id *int64 - subscription_group_id *int64 - addsubscription_group_id *int64 - subscription_days *int - addsubscription_days *int - provider_instance_id *string - status *string - refund_amount *float64 - addrefund_amount *float64 - refund_reason *string - refund_at *time.Time - force_refund *bool - refund_requested_at *time.Time - refund_request_reason *string - refund_requested_by *string - expires_at *time.Time - paid_at *time.Time - completed_at *time.Time - failed_at *time.Time - failed_reason *string - client_ip *string - src_host *string - src_url *string - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - user *int64 - cleareduser bool - done bool - oldValue func(context.Context) (*PaymentOrder, error) - predicates []predicate.PaymentOrder + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name 
*string + provider *channelmonitorrequesttemplate.Provider + description *string + extra_headers *map[string]string + body_override_mode *string + body_override *map[string]interface{} + clearedFields map[string]struct{} + monitors map[int64]struct{} + removedmonitors map[int64]struct{} + clearedmonitors bool + done bool + oldValue func(context.Context) (*ChannelMonitorRequestTemplate, error) + predicates []predicate.ChannelMonitorRequestTemplate } -var _ ent.Mutation = (*PaymentOrderMutation)(nil) +var _ ent.Mutation = (*ChannelMonitorRequestTemplateMutation)(nil) -// paymentorderOption allows management of the mutation configuration using functional options. -type paymentorderOption func(*PaymentOrderMutation) +// channelmonitorrequesttemplateOption allows management of the mutation configuration using functional options. +type channelmonitorrequesttemplateOption func(*ChannelMonitorRequestTemplateMutation) -// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. -func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { - m := &PaymentOrderMutation{ +// newChannelMonitorRequestTemplateMutation creates new mutation for the ChannelMonitorRequestTemplate entity. +func newChannelMonitorRequestTemplateMutation(c config, op Op, opts ...channelmonitorrequesttemplateOption) *ChannelMonitorRequestTemplateMutation { + m := &ChannelMonitorRequestTemplateMutation{ config: c, op: op, - typ: TypePaymentOrder, + typ: TypeChannelMonitorRequestTemplate, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -12809,20 +12623,20 @@ func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *Payme return m } -// withPaymentOrderID sets the ID field of the mutation. -func withPaymentOrderID(id int64) paymentorderOption { - return func(m *PaymentOrderMutation) { +// withChannelMonitorRequestTemplateID sets the ID field of the mutation. 
+func withChannelMonitorRequestTemplateID(id int64) channelmonitorrequesttemplateOption { + return func(m *ChannelMonitorRequestTemplateMutation) { var ( err error once sync.Once - value *PaymentOrder + value *ChannelMonitorRequestTemplate ) - m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { + m.oldValue = func(ctx context.Context) (*ChannelMonitorRequestTemplate, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentOrder.Get(ctx, id) + value, err = m.Client().ChannelMonitorRequestTemplate.Get(ctx, id) } }) return value, err @@ -12831,10 +12645,10 @@ func withPaymentOrderID(id int64) paymentorderOption { } } -// withPaymentOrder sets the old PaymentOrder of the mutation. -func withPaymentOrder(node *PaymentOrder) paymentorderOption { - return func(m *PaymentOrderMutation) { - m.oldValue = func(context.Context) (*PaymentOrder, error) { +// withChannelMonitorRequestTemplate sets the old ChannelMonitorRequestTemplate of the mutation. +func withChannelMonitorRequestTemplate(node *ChannelMonitorRequestTemplate) channelmonitorrequesttemplateOption { + return func(m *ChannelMonitorRequestTemplateMutation) { + m.oldValue = func(context.Context) (*ChannelMonitorRequestTemplate, error) { return node, nil } m.id = &node.ID @@ -12843,7 +12657,7 @@ func withPaymentOrder(node *PaymentOrder) paymentorderOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentOrderMutation) Client() *Client { +func (m ChannelMonitorRequestTemplateMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -12851,7 +12665,7 @@ func (m PaymentOrderMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. 
-func (m PaymentOrderMutation) Tx() (*Tx, error) { +func (m ChannelMonitorRequestTemplateMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -12862,7 +12676,7 @@ func (m PaymentOrderMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *PaymentOrderMutation) ID() (id int64, exists bool) { +func (m *ChannelMonitorRequestTemplateMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -12873,7 +12687,7 @@ func (m *PaymentOrderMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *ChannelMonitorRequestTemplateMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -12882,1656 +12696,10999 @@ func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) + return m.Client().ChannelMonitorRequestTemplate.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetUserID sets the "user_id" field. -func (m *PaymentOrderMutation) SetUserID(i int64) { - m.user = &i +// SetCreatedAt sets the "created_at" field. +func (m *ChannelMonitorRequestTemplateMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// UserID returns the value of the "user_id" field in the mutation. 
-func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { - v := m.user +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserID is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserID requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserID: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.UserID, nil + return oldValue.CreatedAt, nil } -// ResetUserID resets all changes to the "user_id" field. -func (m *PaymentOrderMutation) ResetUserID() { - m.user = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetCreatedAt() { + m.created_at = nil } -// SetUserEmail sets the "user_email" field. 
-func (m *PaymentOrderMutation) SetUserEmail(s string) { - m.user_email = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *ChannelMonitorRequestTemplateMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// UserEmail returns the value of the "user_email" field in the mutation. -func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { - v := m.user_email +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserEmail requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.UserEmail, nil + return oldValue.UpdatedAt, nil } -// ResetUserEmail resets all changes to the "user_email" field. -func (m *PaymentOrderMutation) ResetUserEmail() { - m.user_email = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetUserName sets the "user_name" field. -func (m *PaymentOrderMutation) SetUserName(s string) { - m.user_name = &s +// SetName sets the "name" field. +func (m *ChannelMonitorRequestTemplateMutation) SetName(s string) { + m.name = &s } -// UserName returns the value of the "user_name" field in the mutation. -func (m *PaymentOrderMutation) UserName() (r string, exists bool) { - v := m.user_name +// Name returns the value of the "name" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. 
+// OldName returns the old "name" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserName is only allowed on UpdateOne operations") + return v, errors.New("OldName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserName requires an ID field in the mutation") + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserName: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.UserName, nil + return oldValue.Name, nil } -// ResetUserName resets all changes to the "user_name" field. -func (m *PaymentOrderMutation) ResetUserName() { - m.user_name = nil +// ResetName resets all changes to the "name" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetName() { + m.name = nil } -// SetUserNotes sets the "user_notes" field. -func (m *PaymentOrderMutation) SetUserNotes(s string) { - m.user_notes = &s +// SetProvider sets the "provider" field. +func (m *ChannelMonitorRequestTemplateMutation) SetProvider(c channelmonitorrequesttemplate.Provider) { + m.provider = &c } -// UserNotes returns the value of the "user_notes" field in the mutation. -func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { - v := m.user_notes +// Provider returns the value of the "provider" field in the mutation. 
+func (m *ChannelMonitorRequestTemplateMutation) Provider() (r channelmonitorrequesttemplate.Provider, exists bool) { + v := m.provider if v == nil { return } return *v, true } -// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldProvider returns the old "provider" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldProvider(ctx context.Context) (v channelmonitorrequesttemplate.Provider, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") + return v, errors.New("OldProvider is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserNotes requires an ID field in the mutation") + return v, errors.New("OldProvider requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) + return v, fmt.Errorf("querying old value for OldProvider: %w", err) } - return oldValue.UserNotes, nil + return oldValue.Provider, nil } -// ClearUserNotes clears the value of the "user_notes" field. -func (m *PaymentOrderMutation) ClearUserNotes() { - m.user_notes = nil - m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} +// ResetProvider resets all changes to the "provider" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() { + m.provider = nil } -// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. 
-func (m *PaymentOrderMutation) UserNotesCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldUserNotes] - return ok +// SetDescription sets the "description" field. +func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) { + m.description = &s } -// ResetUserNotes resets all changes to the "user_notes" field. -func (m *PaymentOrderMutation) ResetUserNotes() { - m.user_notes = nil - delete(m.clearedFields, paymentorder.FieldUserNotes) +// Description returns the value of the "description" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true } -// SetAmount sets the "amount" field. -func (m *PaymentOrderMutation) SetAmount(f float64) { - m.amount = &f - m.addamount = nil -} - -// Amount returns the value of the "amount" field in the mutation. -func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { - v := m.amount - if v == nil { - return - } - return *v, true -} - -// OldAmount returns the old "amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDescription returns the old "description" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldDescription(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAmount is only allowed on UpdateOne operations") + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAmount requires an ID field in the mutation") + return v, errors.New("OldDescription requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAmount: %w", err) + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return oldValue.Amount, nil + return oldValue.Description, nil } -// AddAmount adds f to the "amount" field. -func (m *PaymentOrderMutation) AddAmount(f float64) { - if m.addamount != nil { - *m.addamount += f - } else { - m.addamount = &f - } +// ClearDescription clears the value of the "description" field. +func (m *ChannelMonitorRequestTemplateMutation) ClearDescription() { + m.description = nil + m.clearedFields[channelmonitorrequesttemplate.FieldDescription] = struct{}{} } -// AddedAmount returns the value that was added to the "amount" field in this mutation. -func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { - v := m.addamount - if v == nil { - return - } - return *v, true +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldDescription] + return ok } -// ResetAmount resets all changes to the "amount" field. -func (m *PaymentOrderMutation) ResetAmount() { - m.amount = nil - m.addamount = nil +// ResetDescription resets all changes to the "description" field. 
+func (m *ChannelMonitorRequestTemplateMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, channelmonitorrequesttemplate.FieldDescription) } -// SetPayAmount sets the "pay_amount" field. -func (m *PaymentOrderMutation) SetPayAmount(f float64) { - m.pay_amount = &f - m.addpay_amount = nil +// SetExtraHeaders sets the "extra_headers" field. +func (m *ChannelMonitorRequestTemplateMutation) SetExtraHeaders(value map[string]string) { + m.extra_headers = &value } -// PayAmount returns the value of the "pay_amount" field in the mutation. -func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { - v := m.pay_amount +// ExtraHeaders returns the value of the "extra_headers" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) ExtraHeaders() (r map[string]string, exists bool) { + v := m.extra_headers if v == nil { return } return *v, true } -// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") + return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayAmount requires an ID field in the mutation") + return v, errors.New("OldExtraHeaders requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) - } - return oldValue.PayAmount, nil -} - -// AddPayAmount adds f to the "pay_amount" field. -func (m *PaymentOrderMutation) AddPayAmount(f float64) { - if m.addpay_amount != nil { - *m.addpay_amount += f - } else { - m.addpay_amount = &f - } -} - -// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { - v := m.addpay_amount - if v == nil { - return + return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err) } - return *v, true + return oldValue.ExtraHeaders, nil } -// ResetPayAmount resets all changes to the "pay_amount" field. -func (m *PaymentOrderMutation) ResetPayAmount() { - m.pay_amount = nil - m.addpay_amount = nil +// ResetExtraHeaders resets all changes to the "extra_headers" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetExtraHeaders() { + m.extra_headers = nil } -// SetFeeRate sets the "fee_rate" field. -func (m *PaymentOrderMutation) SetFeeRate(f float64) { - m.fee_rate = &f - m.addfee_rate = nil +// SetBodyOverrideMode sets the "body_override_mode" field. 
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverrideMode(s string) { + m.body_override_mode = &s } -// FeeRate returns the value of the "fee_rate" field in the mutation. -func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { - v := m.fee_rate +// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideMode() (r string, exists bool) { + v := m.body_override_mode if v == nil { return } return *v, true } -// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitorRequestTemplate entity. +// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") + return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFeeRate requires an ID field in the mutation") + return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) - } - return oldValue.FeeRate, nil -} - -// AddFeeRate adds f to the "fee_rate" field. 
-func (m *PaymentOrderMutation) AddFeeRate(f float64) { - if m.addfee_rate != nil { - *m.addfee_rate += f - } else { - m.addfee_rate = &f - } -} - -// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. -func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { - v := m.addfee_rate - if v == nil { - return + return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err) } - return *v, true + return oldValue.BodyOverrideMode, nil } -// ResetFeeRate resets all changes to the "fee_rate" field. -func (m *PaymentOrderMutation) ResetFeeRate() { - m.fee_rate = nil - m.addfee_rate = nil +// ResetBodyOverrideMode resets all changes to the "body_override_mode" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverrideMode() { + m.body_override_mode = nil } -// SetRechargeCode sets the "recharge_code" field. -func (m *PaymentOrderMutation) SetRechargeCode(s string) { - m.recharge_code = &s +// SetBodyOverride sets the "body_override" field. +func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverride(value map[string]interface{}) { + m.body_override = &value } -// RechargeCode returns the value of the "recharge_code" field in the mutation. -func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { - v := m.recharge_code +// BodyOverride returns the value of the "body_override" field in the mutation. +func (m *ChannelMonitorRequestTemplateMutation) BodyOverride() (r map[string]interface{}, exists bool) { + v := m.body_override if v == nil { return } return *v, true } -// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitorRequestTemplate entity. 
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { +func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") + return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRechargeCode requires an ID field in the mutation") + return v, errors.New("OldBodyOverride requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) + return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err) } - return oldValue.RechargeCode, nil + return oldValue.BodyOverride, nil } -// ResetRechargeCode resets all changes to the "recharge_code" field. -func (m *PaymentOrderMutation) ResetRechargeCode() { - m.recharge_code = nil +// ClearBodyOverride clears the value of the "body_override" field. +func (m *ChannelMonitorRequestTemplateMutation) ClearBodyOverride() { + m.body_override = nil + m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride] = struct{}{} } -// SetOutTradeNo sets the "out_trade_no" field. -func (m *PaymentOrderMutation) SetOutTradeNo(s string) { - m.out_trade_no = &s +// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideCleared() bool { + _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride] + return ok } -// OutTradeNo returns the value of the "out_trade_no" field in the mutation. 
-func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { - v := m.out_trade_no - if v == nil { - return - } - return *v, true +// ResetBodyOverride resets all changes to the "body_override" field. +func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverride() { + m.body_override = nil + delete(m.clearedFields, channelmonitorrequesttemplate.FieldBodyOverride) } -// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOutTradeNo requires an ID field in the mutation") +// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by ids. +func (m *ChannelMonitorRequestTemplateMutation) AddMonitorIDs(ids ...int64) { + if m.monitors == nil { + m.monitors = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) + for i := range ids { + m.monitors[ids[i]] = struct{}{} } - return oldValue.OutTradeNo, nil } -// ResetOutTradeNo resets all changes to the "out_trade_no" field. -func (m *PaymentOrderMutation) ResetOutTradeNo() { - m.out_trade_no = nil +// ClearMonitors clears the "monitors" edge to the ChannelMonitor entity. +func (m *ChannelMonitorRequestTemplateMutation) ClearMonitors() { + m.clearedmonitors = true } -// SetPaymentType sets the "payment_type" field. -func (m *PaymentOrderMutation) SetPaymentType(s string) { - m.payment_type = &s +// MonitorsCleared reports if the "monitors" edge to the ChannelMonitor entity was cleared. 
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsCleared() bool { + return m.clearedmonitors } -// PaymentType returns the value of the "payment_type" field in the mutation. -func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { - v := m.payment_type - if v == nil { - return +// RemoveMonitorIDs removes the "monitors" edge to the ChannelMonitor entity by IDs. +func (m *ChannelMonitorRequestTemplateMutation) RemoveMonitorIDs(ids ...int64) { + if m.removedmonitors == nil { + m.removedmonitors = make(map[int64]struct{}) + } + for i := range ids { + delete(m.monitors, ids[i]) + m.removedmonitors[ids[i]] = struct{}{} } - return *v, true } -// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) +// RemovedMonitors returns the removed IDs of the "monitors" edge to the ChannelMonitor entity. +func (m *ChannelMonitorRequestTemplateMutation) RemovedMonitorsIDs() (ids []int64) { + for id := range m.removedmonitors { + ids = append(ids, id) } - return oldValue.PaymentType, nil + return } -// ResetPaymentType resets all changes to the "payment_type" field. -func (m *PaymentOrderMutation) ResetPaymentType() { - m.payment_type = nil +// MonitorsIDs returns the "monitors" edge IDs in the mutation. 
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsIDs() (ids []int64) { + for id := range m.monitors { + ids = append(ids, id) + } + return } -// SetPaymentTradeNo sets the "payment_trade_no" field. -func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { - m.payment_trade_no = &s +// ResetMonitors resets all changes to the "monitors" edge. +func (m *ChannelMonitorRequestTemplateMutation) ResetMonitors() { + m.monitors = nil + m.clearedmonitors = false + m.removedmonitors = nil } -// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. -func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { - v := m.payment_trade_no - if v == nil { - return - } - return *v, true +// Where appends a list predicates to the ChannelMonitorRequestTemplateMutation builder. +func (m *ChannelMonitorRequestTemplateMutation) Where(ps ...predicate.ChannelMonitorRequestTemplate) { + m.predicates = append(m.predicates, ps...) } -// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) +// WhereP appends storage-level predicates to the ChannelMonitorRequestTemplateMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. 
+func (m *ChannelMonitorRequestTemplateMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ChannelMonitorRequestTemplate, len(ps)) + for i := range ps { + p[i] = ps[i] } - return oldValue.PaymentTradeNo, nil + m.Where(p...) } -// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. -func (m *PaymentOrderMutation) ResetPaymentTradeNo() { - m.payment_trade_no = nil +// Op returns the operation name. +func (m *ChannelMonitorRequestTemplateMutation) Op() Op { + return m.op } -// SetPayURL sets the "pay_url" field. -func (m *PaymentOrderMutation) SetPayURL(s string) { - m.pay_url = &s +// SetOp allows setting the mutation operation. +func (m *ChannelMonitorRequestTemplateMutation) SetOp(op Op) { + m.op = op } -// PayURL returns the value of the "pay_url" field in the mutation. -func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { - v := m.pay_url - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (ChannelMonitorRequestTemplate). +func (m *ChannelMonitorRequestTemplateMutation) Type() string { + return m.typ } -// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayURL is only allowed on UpdateOne operations") +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *ChannelMonitorRequestTemplateMutation) Fields() []string { + fields := make([]string, 0, 8) + if m.created_at != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayURL requires an ID field in the mutation") + if m.updated_at != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldUpdatedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPayURL: %w", err) + if m.name != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldName) } - return oldValue.PayURL, nil + if m.provider != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldProvider) + } + if m.description != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldDescription) + } + if m.extra_headers != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldExtraHeaders) + } + if m.body_override_mode != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverrideMode) + } + if m.body_override != nil { + fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride) + } + return fields } -// ClearPayURL clears the value of the "pay_url" field. -func (m *PaymentOrderMutation) ClearPayURL() { - m.pay_url = nil - m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, bool) { + switch name { + case channelmonitorrequesttemplate.FieldCreatedAt: + return m.CreatedAt() + case channelmonitorrequesttemplate.FieldUpdatedAt: + return m.UpdatedAt() + case channelmonitorrequesttemplate.FieldName: + return m.Name() + case channelmonitorrequesttemplate.FieldProvider: + return m.Provider() + case channelmonitorrequesttemplate.FieldDescription: + return m.Description() + case channelmonitorrequesttemplate.FieldExtraHeaders: + return m.ExtraHeaders() + case channelmonitorrequesttemplate.FieldBodyOverrideMode: + return m.BodyOverrideMode() + case channelmonitorrequesttemplate.FieldBodyOverride: + return m.BodyOverride() + } + return nil, false } -// PayURLCleared returns if the "pay_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) PayURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPayURL] - return ok +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case channelmonitorrequesttemplate.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case channelmonitorrequesttemplate.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case channelmonitorrequesttemplate.FieldName: + return m.OldName(ctx) + case channelmonitorrequesttemplate.FieldProvider: + return m.OldProvider(ctx) + case channelmonitorrequesttemplate.FieldDescription: + return m.OldDescription(ctx) + case channelmonitorrequesttemplate.FieldExtraHeaders: + return m.OldExtraHeaders(ctx) + case channelmonitorrequesttemplate.FieldBodyOverrideMode: + return m.OldBodyOverrideMode(ctx) + case channelmonitorrequesttemplate.FieldBodyOverride: + return m.OldBodyOverride(ctx) + } + return nil, fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name) } -// ResetPayURL resets all changes to the "pay_url" field. -func (m *PaymentOrderMutation) ResetPayURL() { - m.pay_url = nil - delete(m.clearedFields, paymentorder.FieldPayURL) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.Value) error { + switch name { + case channelmonitorrequesttemplate.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case channelmonitorrequesttemplate.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case channelmonitorrequesttemplate.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case channelmonitorrequesttemplate.FieldProvider: + v, ok := value.(channelmonitorrequesttemplate.Provider) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProvider(v) + return nil + case channelmonitorrequesttemplate.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + case channelmonitorrequesttemplate.FieldExtraHeaders: + v, ok := value.(map[string]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExtraHeaders(v) + return nil + case channelmonitorrequesttemplate.FieldBodyOverrideMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBodyOverrideMode(v) + return nil + case channelmonitorrequesttemplate.FieldBodyOverride: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBodyOverride(v) + return nil + } + return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name) } -// SetQrCode sets the "qr_code" field. 
-func (m *PaymentOrderMutation) SetQrCode(s string) { - m.qr_code = &s +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ChannelMonitorRequestTemplateMutation) AddedFields() []string { + return nil } -// QrCode returns the value of the "qr_code" field in the mutation. -func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { - v := m.qr_code - if v == nil { - return - } - return *v, true +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ChannelMonitorRequestTemplateMutation) AddedField(name string) (ent.Value, bool) { + return nil, false } -// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCode requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldQrCode: %w", err) +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ChannelMonitorRequestTemplateMutation) AddField(name string, value ent.Value) error { + switch name { } - return oldValue.QrCode, nil + return fmt.Errorf("unknown ChannelMonitorRequestTemplate numeric field %s", name) } -// ClearQrCode clears the value of the "qr_code" field. 
-func (m *PaymentOrderMutation) ClearQrCode() { - m.qr_code = nil - m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ChannelMonitorRequestTemplateMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(channelmonitorrequesttemplate.FieldDescription) { + fields = append(fields, channelmonitorrequesttemplate.FieldDescription) + } + if m.FieldCleared(channelmonitorrequesttemplate.FieldBodyOverride) { + fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride) + } + return fields } -// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. -func (m *PaymentOrderMutation) QrCodeCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCode] +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] return ok } -// ResetQrCode resets all changes to the "qr_code" field. -func (m *PaymentOrderMutation) ResetQrCode() { - m.qr_code = nil - delete(m.clearedFields, paymentorder.FieldQrCode) +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ChannelMonitorRequestTemplateMutation) ClearField(name string) error { + switch name { + case channelmonitorrequesttemplate.FieldDescription: + m.ClearDescription() + return nil + case channelmonitorrequesttemplate.FieldBodyOverride: + m.ClearBodyOverride() + return nil + } + return fmt.Errorf("unknown ChannelMonitorRequestTemplate nullable field %s", name) } -// SetQrCodeImg sets the "qr_code_img" field. -func (m *PaymentOrderMutation) SetQrCodeImg(s string) { - m.qr_code_img = &s +// ResetField resets all changes in the mutation for the field with the given name. 
+// It returns an error if the field is not defined in the schema. +func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error { + switch name { + case channelmonitorrequesttemplate.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case channelmonitorrequesttemplate.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case channelmonitorrequesttemplate.FieldName: + m.ResetName() + return nil + case channelmonitorrequesttemplate.FieldProvider: + m.ResetProvider() + return nil + case channelmonitorrequesttemplate.FieldDescription: + m.ResetDescription() + return nil + case channelmonitorrequesttemplate.FieldExtraHeaders: + m.ResetExtraHeaders() + return nil + case channelmonitorrequesttemplate.FieldBodyOverrideMode: + m.ResetBodyOverrideMode() + return nil + case channelmonitorrequesttemplate.FieldBodyOverride: + m.ResetBodyOverride() + return nil + } + return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name) } -// QrCodeImg returns the value of the "qr_code_img" field in the mutation. -func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { - v := m.qr_code_img - if v == nil { - return +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.monitors != nil { + edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors) } - return *v, true + return edges } -// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCodeImg requires an ID field in the mutation") +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) AddedIDs(name string) []ent.Value { + switch name { + case channelmonitorrequesttemplate.EdgeMonitors: + ids := make([]ent.Value, 0, len(m.monitors)) + for id := range m.monitors { + ids = append(ids, id) + } + return ids } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedmonitors != nil { + edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors) } - return oldValue.QrCodeImg, nil + return edges } -// ClearQrCodeImg clears the value of the "qr_code_img" field. -func (m *PaymentOrderMutation) ClearQrCodeImg() { - m.qr_code_img = nil - m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) RemovedIDs(name string) []ent.Value { + switch name { + case channelmonitorrequesttemplate.EdgeMonitors: + ids := make([]ent.Value, 0, len(m.removedmonitors)) + for id := range m.removedmonitors { + ids = append(ids, id) + } + return ids + } + return nil } -// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. 
-func (m *PaymentOrderMutation) QrCodeImgCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] - return ok +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedmonitors { + edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors) + } + return edges } -// ResetQrCodeImg resets all changes to the "qr_code_img" field. -func (m *PaymentOrderMutation) ResetQrCodeImg() { - m.qr_code_img = nil - delete(m.clearedFields, paymentorder.FieldQrCodeImg) +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ChannelMonitorRequestTemplateMutation) EdgeCleared(name string) bool { + switch name { + case channelmonitorrequesttemplate.EdgeMonitors: + return m.clearedmonitors + } + return false } -// SetOrderType sets the "order_type" field. -func (m *PaymentOrderMutation) SetOrderType(s string) { - m.order_type = &s +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ChannelMonitorRequestTemplateMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown ChannelMonitorRequestTemplate unique edge %s", name) } -// OrderType returns the value of the "order_type" field in the mutation. -func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { - v := m.order_type - if v == nil { - return +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. 
+func (m *ChannelMonitorRequestTemplateMutation) ResetEdge(name string) error { + switch name { + case channelmonitorrequesttemplate.EdgeMonitors: + m.ResetMonitors() + return nil } - return *v, true + return fmt.Errorf("unknown ChannelMonitorRequestTemplate edge %s", name) } -// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldOrderType: %w", err) - } - return oldValue.OrderType, nil +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + skip_monitoring *bool + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule } -// ResetOrderType resets all changes to the "order_type" field. 
-func (m *PaymentOrderMutation) ResetOrderType() { - m.order_type = nil -} +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) -// SetPlanID sets the "plan_id" field. -func (m *PaymentOrderMutation) SetPlanID(i int64) { - m.plan_id = &i - m.addplan_id = nil -} +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) -// PlanID returns the value of the "plan_id" field in the mutation. -func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { - v := m.plan_id - if v == nil { - return +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), } - return *v, true + for _, opt := range opts { + opt(m) + } + return m } -// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlanID is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlanID requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPlanID: %w", err) +// withErrorPassthroughRuleID sets the ID field of the mutation. 
+func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule + ) + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - return oldValue.PlanID, nil } -// AddPlanID adds i to the "plan_id" field. -func (m *PaymentOrderMutation) AddPlanID(i int64) { - if m.addplan_id != nil { - *m.addplan_id += i - } else { - m.addplan_id = &i +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { + return node, nil + } + m.id = &node.ID } } -// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. -func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { - v := m.addplan_id - if v == nil { - return - } - return *v, true +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ErrorPassthroughRuleMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// ClearPlanID clears the value of the "plan_id" field. -func (m *PaymentOrderMutation) ClearPlanID() { - m.plan_id = nil - m.addplan_id = nil - m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. 
+func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) PlanIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPlanID] - return ok +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// ResetPlanID resets all changes to the "plan_id" field. -func (m *PaymentOrderMutation) ResetPlanID() { - m.plan_id = nil - m.addplan_id = nil - delete(m.clearedFields, paymentorder.FieldPlanID) +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } } -// SetSubscriptionGroupID sets the "subscription_group_id" field. -func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { - m.subscription_group_id = &i - m.addsubscription_group_id = nil +// SetCreatedAt sets the "created_at" field. 
+func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. -func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { - v := m.subscription_group_id +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) - } - return oldValue.SubscriptionGroupID, nil -} - -// AddSubscriptionGroupID adds i to the "subscription_group_id" field. 
-func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { - if m.addsubscription_group_id != nil { - *m.addsubscription_group_id += i - } else { - m.addsubscription_group_id = &i - } -} - -// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. -func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { - v := m.addsubscription_group_id - if v == nil { - return + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return *v, true -} - -// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. -func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} -} - -// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] - return ok + return oldValue.CreatedAt, nil } -// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. -func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { + m.created_at = nil } -// SetSubscriptionDays sets the "subscription_days" field. -func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { - m.subscription_days = &i - m.addsubscription_days = nil +// SetUpdatedAt sets the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// SubscriptionDays returns the value of the "subscription_days" field in the mutation. 
-func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { - v := m.subscription_days +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) - } - return oldValue.SubscriptionDays, nil -} - -// AddSubscriptionDays adds i to the "subscription_days" field. -func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { - if m.addsubscription_days != nil { - *m.addsubscription_days += i - } else { - m.addsubscription_days = &i - } -} - -// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. 
-func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { - v := m.addsubscription_days - if v == nil { - return + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return *v, true -} - -// ClearSubscriptionDays clears the value of the "subscription_days" field. -func (m *PaymentOrderMutation) ClearSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} -} - -// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] - return ok + return oldValue.UpdatedAt, nil } -// ResetSubscriptionDays resets all changes to the "subscription_days" field. -func (m *PaymentOrderMutation) ResetSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetProviderInstanceID sets the "provider_instance_id" field. -func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { - m.provider_instance_id = &s +// SetName sets the "name" field. +func (m *ErrorPassthroughRuleMutation) SetName(s string) { + m.name = &s } -// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. -func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { - v := m.provider_instance_id +// Name returns the value of the "name" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. 
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") + return v, errors.New("OldName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.ProviderInstanceID, nil -} - -// ClearProviderInstanceID clears the value of the "provider_instance_id" field. -func (m *PaymentOrderMutation) ClearProviderInstanceID() { - m.provider_instance_id = nil - m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} -} - -// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] - return ok + return oldValue.Name, nil } -// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. 
-func (m *PaymentOrderMutation) ResetProviderInstanceID() { - m.provider_instance_id = nil - delete(m.clearedFields, paymentorder.FieldProviderInstanceID) +// ResetName resets all changes to the "name" field. +func (m *ErrorPassthroughRuleMutation) ResetName() { + m.name = nil } -// SetStatus sets the "status" field. -func (m *PaymentOrderMutation) SetStatus(s string) { - m.status = &s +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b } -// Status returns the value of the "status" field in the mutation. -func (m *PaymentOrderMutation) Status() (r string, exists bool) { - v := m.status +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled if v == nil { return } return *v, true } -// OldStatus returns the old "status" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + return v, errors.New("OldEnabled requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) } - return oldValue.Status, nil + return oldValue.Enabled, nil } -// ResetStatus resets all changes to the "status" field. -func (m *PaymentOrderMutation) ResetStatus() { - m.status = nil +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil } -// SetRefundAmount sets the "refund_amount" field. -func (m *PaymentOrderMutation) SetRefundAmount(f float64) { - m.refund_amount = &f - m.addrefund_amount = nil +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil } -// RefundAmount returns the value of the "refund_amount" field in the mutation. -func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { - v := m.refund_amount +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority if v == nil { return } return *v, true } -// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. 
+// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") + return v, errors.New("OldPriority is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAmount requires an ID field in the mutation") + return v, errors.New("OldPriority requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + return v, fmt.Errorf("querying old value for OldPriority: %w", err) } - return oldValue.RefundAmount, nil + return oldValue.Priority, nil } -// AddRefundAmount adds f to the "refund_amount" field. -func (m *PaymentOrderMutation) AddRefundAmount(f float64) { - if m.addrefund_amount != nil { - *m.addrefund_amount += f +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i } else { - m.addrefund_amount = &f + m.addpriority = &i } } -// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { - v := m.addrefund_amount +// AddedPriority returns the value that was added to the "priority" field in this mutation. 
+func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority if v == nil { return } return *v, true } -// ResetRefundAmount resets all changes to the "refund_amount" field. -func (m *PaymentOrderMutation) ResetRefundAmount() { - m.refund_amount = nil - m.addrefund_amount = nil +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil } -// SetRefundReason sets the "refund_reason" field. -func (m *PaymentOrderMutation) SetRefundReason(s string) { - m.refund_reason = &s +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil } -// RefundReason returns the value of the "refund_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { - v := m.refund_reason +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes if v == nil { return } return *v, true } -// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundReason requires an ID field in the mutation") + return v, errors.New("OldErrorCodes requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) } - return oldValue.RefundReason, nil + return oldValue.ErrorCodes, nil } -// ClearRefundReason clears the value of the "refund_reason" field. -func (m *PaymentOrderMutation) ClearRefundReason() { - m.refund_reason = nil - m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) } -// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundReason] - return ok +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true } -// ResetRefundReason resets all changes to the "refund_reason" field. 
-func (m *PaymentOrderMutation) ResetRefundReason() { - m.refund_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundReason) +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} } -// SetRefundAt sets the "refund_at" field. -func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { - m.refund_at = &t +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok } -// RefundAt returns the value of the "refund_at" field in the mutation. -func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { - v := m.refund_at - if v == nil { - return - } - return *v, true +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) } -// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil +} + +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords + if v == nil { + return + } + return *v, true +} + +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAt requires an ID field in the mutation") + return v, errors.New("OldKeywords requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) } - return oldValue.RefundAt, nil + return oldValue.Keywords, nil } -// ClearRefundAt clears the value of the "refund_at" field. -func (m *PaymentOrderMutation) ClearRefundAt() { - m.refund_at = nil - m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) } -// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundAt] +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true +} + +// ClearKeywords clears the value of the "keywords" field. 
+func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] return ok } -// ResetRefundAt resets all changes to the "refund_at" field. -func (m *PaymentOrderMutation) ResetRefundAt() { - m.refund_at = nil - delete(m.clearedFields, paymentorder.FieldRefundAt) +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) } -// SetForceRefund sets the "force_refund" field. -func (m *PaymentOrderMutation) SetForceRefund(b bool) { - m.force_refund = &b +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s } -// ForceRefund returns the value of the "force_refund" field in the mutation. -func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { - v := m.force_refund +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode if v == nil { return } return *v, true } -// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldForceRefund requires an ID field in the mutation") + return v, errors.New("OldMatchMode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) } - return oldValue.ForceRefund, nil + return oldValue.MatchMode, nil } -// ResetForceRefund resets all changes to the "force_refund" field. -func (m *PaymentOrderMutation) ResetForceRefund() { - m.force_refund = nil +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil } -// SetRefundRequestedAt sets the "refund_requested_at" field. -func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { - m.refund_requested_at = &t +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil } -// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { - v := m.refund_requested_at +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms if v == nil { return } return *v, true } -// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. 
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + return v, errors.New("OldPlatforms requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) } - return oldValue.RefundRequestedAt, nil + return oldValue.Platforms, nil } -// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. -func (m *PaymentOrderMutation) ClearRefundRequestedAt() { - m.refund_requested_at = nil - m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) } -// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. 
-func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false + } + return m.appendplatforms, true +} + +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +} + +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] return ok } -// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. -func (m *PaymentOrderMutation) ResetRefundRequestedAt() { - m.refund_requested_at = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) } -// SetRefundRequestReason sets the "refund_request_reason" field. -func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { - m.refund_request_reason = &s +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b } -// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { - v := m.refund_request_reason +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. 
+func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code if v == nil { return } return *v, true } -// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) } - return oldValue.RefundRequestReason, nil -} - -// ClearRefundRequestReason clears the value of the "refund_request_reason" field. -func (m *PaymentOrderMutation) ClearRefundRequestReason() { - m.refund_request_reason = nil - m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} -} - -// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. 
-func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] - return ok + return oldValue.PassthroughCode, nil } -// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. -func (m *PaymentOrderMutation) ResetRefundRequestReason() { - m.refund_request_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil } -// SetRefundRequestedBy sets the "refund_requested_by" field. -func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { - m.refund_requested_by = &s +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil } -// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { - v := m.refund_requested_by +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code if v == nil { return } return *v, true } -// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + return v, errors.New("OldResponseCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) } - return oldValue.RefundRequestedBy, nil + return oldValue.ResponseCode, nil } -// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. -func (m *PaymentOrderMutation) ClearRefundRequestedBy() { - m.refund_requested_by = nil - m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i + } else { + m.addresponse_code = &i + } } -// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code + if v == nil { + return + } + return *v, true +} + +// ClearResponseCode clears the value of the "response_code" field. 
+func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +} + +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] return ok } -// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. -func (m *PaymentOrderMutation) ResetRefundRequestedBy() { - m.refund_requested_by = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) } -// SetExpiresAt sets the "expires_at" field. -func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body if v == nil { return } return *v, true } -// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. 
+// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) } - return oldValue.ExpiresAt, nil + return oldValue.PassthroughBody, nil } -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *PaymentOrderMutation) ResetExpiresAt() { - m.expires_at = nil +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil } -// SetPaidAt sets the "paid_at" field. -func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { - m.paid_at = &t +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s } -// PaidAt returns the value of the "paid_at" field in the mutation. -func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { - v := m.paid_at +// CustomMessage returns the value of the "custom_message" field in the mutation. 
+func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message if v == nil { return } return *v, true } -// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaidAt requires an ID field in the mutation") + return v, errors.New("OldCustomMessage requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) } - return oldValue.PaidAt, nil + return oldValue.CustomMessage, nil } -// ClearPaidAt clears the value of the "paid_at" field. -func (m *PaymentOrderMutation) ClearPaidAt() { - m.paid_at = nil - m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} } -// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. 
-func (m *PaymentOrderMutation) PaidAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPaidAt] +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] return ok } -// ResetPaidAt resets all changes to the "paid_at" field. -func (m *PaymentOrderMutation) ResetPaidAt() { - m.paid_at = nil - delete(m.clearedFields, paymentorder.FieldPaidAt) +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) } -// SetCompletedAt sets the "completed_at" field. -func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { - m.completed_at = &t +// SetSkipMonitoring sets the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { + m.skip_monitoring = &b } -// CompletedAt returns the value of the "completed_at" field in the mutation. -func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { - v := m.completed_at +// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. +func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { + v := m.skip_monitoring if v == nil { return } return *v, true } -// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") + return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCompletedAt requires an ID field in the mutation") + return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) + return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) } - return oldValue.CompletedAt, nil -} - -// ClearCompletedAt clears the value of the "completed_at" field. -func (m *PaymentOrderMutation) ClearCompletedAt() { - m.completed_at = nil - m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} -} - -// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) CompletedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldCompletedAt] - return ok + return oldValue.SkipMonitoring, nil } -// ResetCompletedAt resets all changes to the "completed_at" field. -func (m *PaymentOrderMutation) ResetCompletedAt() { - m.completed_at = nil - delete(m.clearedFields, paymentorder.FieldCompletedAt) +// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { + m.skip_monitoring = nil } -// SetFailedAt sets the "failed_at" field. -func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { - m.failed_at = &t +// SetDescription sets the "description" field. 
+func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s } -// FailedAt returns the value of the "failed_at" field in the mutation. -func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { - v := m.failed_at +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description if v == nil { return } return *v, true } -// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedAt requires an ID field in the mutation") + return v, errors.New("OldDescription requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return oldValue.FailedAt, nil + return oldValue.Description, nil } -// ClearFailedAt clears the value of the "failed_at" field. 
-func (m *PaymentOrderMutation) ClearFailedAt() { - m.failed_at = nil - m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} } -// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedAt] +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] return ok } -// ResetFailedAt resets all changes to the "failed_at" field. -func (m *PaymentOrderMutation) ResetFailedAt() { - m.failed_at = nil - delete(m.clearedFields, paymentorder.FieldFailedAt) -} - -// SetFailedReason sets the "failed_reason" field. -func (m *PaymentOrderMutation) SetFailedReason(s string) { - m.failed_reason = &s +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) } -// FailedReason returns the value of the "failed_reason" field in the mutation. -func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { - v := m.failed_reason - if v == nil { - return - } - return *v, true +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) } -// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. 
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedReason requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] } - return oldValue.FailedReason, nil -} - -// ClearFailedReason clears the value of the "failed_reason" field. -func (m *PaymentOrderMutation) ClearFailedReason() { - m.failed_reason = nil - m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} + m.Where(p...) } -// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedReason] - return ok +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op } -// ResetFailedReason resets all changes to the "failed_reason" field. -func (m *PaymentOrderMutation) ResetFailedReason() { - m.failed_reason = nil - delete(m.clearedFields, paymentorder.FieldFailedReason) +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op } -// SetClientIP sets the "client_ip" field. 
-func (m *PaymentOrderMutation) SetClientIP(s string) { - m.client_ip = &s +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ } -// ClientIP returns the value of the "client_ip" field in the mutation. -func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { - v := m.client_ip +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 15) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) + } + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) + } + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) + } + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + } + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + } + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.skip_monitoring != nil { + fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + } + 
if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldSkipMonitoring: + return m.SkipMonitoring() + case errorpassthroughrule.FieldDescription: + return m.Description() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldSkipMonitoring: + return m.OldSkipMonitoring(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) + } + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case 
errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldSkipMonitoring: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSkipMonitoring(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false +} + +// AddField adds the value to the field with the given name. 
It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. 
It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return 
nil + case errorpassthroughrule.FieldSkipMonitoring: + m.ResetSkipMonitoring() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. 
+func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) +} + +// GroupMutation represents an operation that mutates the Group nodes in the graph. +type GroupMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + sort_order *int + addsort_order *int + allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool + default_mapped_model *string + messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + rpm_limit *int + addrpm_limit *int + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + 
accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group +} + +var _ ent.Mutation = (*GroupMutation)(nil) + +// groupOption allows management of the mutation configuration using functional options. +type groupOption func(*GroupMutation) + +// newGroupMutation creates new mutation for the Group entity. +func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { + m := &GroupMutation{ + config: c, + op: op, + typ: TypeGroup, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withGroupID sets the ID field of the mutation. +func withGroupID(id int64) groupOption { + return func(m *GroupMutation) { + var ( + err error + once sync.Once + value *Group + ) + m.oldValue = func(ctx context.Context) (*Group, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Group.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withGroup sets the old Group of the mutation. +func withGroup(node *Group) groupOption { + return func(m *GroupMutation) { + m.oldValue = func(context.Context) (*Group, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m GroupMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. 
+func (m GroupMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *GroupMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *GroupMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *GroupMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *GroupMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *GroupMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. 
+func (m *GroupMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *GroupMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[group.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *GroupMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[group.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *GroupMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, group.FieldDeletedAt) +} + +// SetName sets the "name" field. +func (m *GroupMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *GroupMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Group entity. 
+// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *GroupMutation) ResetName() { + m.name = nil +} + +// SetDescription sets the "description" field. +func (m *GroupMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *GroupMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. 
+func (m *GroupMutation) ClearDescription() { + m.description = nil + m.clearedFields[group.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *GroupMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[group.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *GroupMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, group.FieldDescription) +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (m *GroupMutation) SetRateMultiplier(f float64) { + m.rate_multiplier = &f + m.addrate_multiplier = nil +} + +// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. +func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { + v := m.rate_multiplier + if v == nil { + return + } + return *v, true +} + +// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) + } + return oldValue.RateMultiplier, nil +} + +// AddRateMultiplier adds f to the "rate_multiplier" field. 
+func (m *GroupMutation) AddRateMultiplier(f float64) { + if m.addrate_multiplier != nil { + *m.addrate_multiplier += f + } else { + m.addrate_multiplier = &f + } +} + +// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. +func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { + v := m.addrate_multiplier + if v == nil { + return + } + return *v, true +} + +// ResetRateMultiplier resets all changes to the "rate_multiplier" field. +func (m *GroupMutation) ResetRateMultiplier() { + m.rate_multiplier = nil + m.addrate_multiplier = nil +} + +// SetIsExclusive sets the "is_exclusive" field. +func (m *GroupMutation) SetIsExclusive(b bool) { + m.is_exclusive = &b +} + +// IsExclusive returns the value of the "is_exclusive" field in the mutation. +func (m *GroupMutation) IsExclusive() (r bool, exists bool) { + v := m.is_exclusive + if v == nil { + return + } + return *v, true +} + +// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsExclusive requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) + } + return oldValue.IsExclusive, nil +} + +// ResetIsExclusive resets all changes to the "is_exclusive" field. +func (m *GroupMutation) ResetIsExclusive() { + m.is_exclusive = nil +} + +// SetStatus sets the "status" field. 
+func (m *GroupMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *GroupMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *GroupMutation) ResetStatus() { + m.status = nil +} + +// SetPlatform sets the "platform" field. +func (m *GroupMutation) SetPlatform(s string) { + m.platform = &s +} + +// Platform returns the value of the "platform" field in the mutation. +func (m *GroupMutation) Platform() (r string, exists bool) { + v := m.platform + if v == nil { + return + } + return *v, true +} + +// OldPlatform returns the old "platform" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldPlatform is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldPlatform requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldPlatform: %w", err)
	}
	return oldValue.Platform, nil
}

// NOTE(review): the accessors below follow the entgo.io generated-mutation
// pattern (normally a "Code generated by ent, DO NOT EDIT." file — the header
// is outside this chunk, so TODO confirm). Do not hand-edit these methods;
// change the ent schema and re-run `go generate ./...` instead.

// ResetPlatform resets all changes to the "platform" field.
func (m *GroupMutation) ResetPlatform() {
	m.platform = nil
}

// SetSubscriptionType sets the "subscription_type" field.
func (m *GroupMutation) SetSubscriptionType(s string) {
	m.subscription_type = &s
}

// SubscriptionType returns the value of the "subscription_type" field in the mutation.
func (m *GroupMutation) SubscriptionType() (r string, exists bool) {
	v := m.subscription_type
	if v == nil {
		return
	}
	return *v, true
}

// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldSubscriptionType requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err)
	}
	return oldValue.SubscriptionType, nil
}

// ResetSubscriptionType resets all changes to the "subscription_type" field.
func (m *GroupMutation) ResetSubscriptionType() {
	m.subscription_type = nil
}

// SetDailyLimitUsd sets the "daily_limit_usd" field.
func (m *GroupMutation) SetDailyLimitUsd(f float64) {
	m.daily_limit_usd = &f
	m.adddaily_limit_usd = nil
}

// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation.
func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) {
	v := m.daily_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err)
	}
	return oldValue.DailyLimitUsd, nil
}

// AddDailyLimitUsd adds f to the "daily_limit_usd" field.
func (m *GroupMutation) AddDailyLimitUsd(f float64) {
	if m.adddaily_limit_usd != nil {
		*m.adddaily_limit_usd += f
	} else {
		m.adddaily_limit_usd = &f
	}
}

// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation.
func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) {
	v := m.adddaily_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
func (m *GroupMutation) ClearDailyLimitUsd() {
	m.daily_limit_usd = nil
	m.adddaily_limit_usd = nil
	m.clearedFields[group.FieldDailyLimitUsd] = struct{}{}
}

// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation.
func (m *GroupMutation) DailyLimitUsdCleared() bool {
	_, ok := m.clearedFields[group.FieldDailyLimitUsd]
	return ok
}

// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field.
func (m *GroupMutation) ResetDailyLimitUsd() {
	m.daily_limit_usd = nil
	m.adddaily_limit_usd = nil
	delete(m.clearedFields, group.FieldDailyLimitUsd)
}

// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
func (m *GroupMutation) SetWeeklyLimitUsd(f float64) {
	m.weekly_limit_usd = &f
	m.addweekly_limit_usd = nil
}

// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation.
func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) {
	v := m.weekly_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err)
	}
	return oldValue.WeeklyLimitUsd, nil
}

// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field.
func (m *GroupMutation) AddWeeklyLimitUsd(f float64) {
	if m.addweekly_limit_usd != nil {
		*m.addweekly_limit_usd += f
	} else {
		m.addweekly_limit_usd = &f
	}
}

// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation.
func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) {
	v := m.addweekly_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
func (m *GroupMutation) ClearWeeklyLimitUsd() {
	m.weekly_limit_usd = nil
	m.addweekly_limit_usd = nil
	m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{}
}

// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation.
func (m *GroupMutation) WeeklyLimitUsdCleared() bool {
	_, ok := m.clearedFields[group.FieldWeeklyLimitUsd]
	return ok
}

// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field.
func (m *GroupMutation) ResetWeeklyLimitUsd() {
	m.weekly_limit_usd = nil
	m.addweekly_limit_usd = nil
	delete(m.clearedFields, group.FieldWeeklyLimitUsd)
}

// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
func (m *GroupMutation) SetMonthlyLimitUsd(f float64) {
	m.monthly_limit_usd = &f
	m.addmonthly_limit_usd = nil
}

// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation.
func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) {
	v := m.monthly_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err)
	}
	return oldValue.MonthlyLimitUsd, nil
}

// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field.
func (m *GroupMutation) AddMonthlyLimitUsd(f float64) {
	if m.addmonthly_limit_usd != nil {
		*m.addmonthly_limit_usd += f
	} else {
		m.addmonthly_limit_usd = &f
	}
}

// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation.
func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) {
	v := m.addmonthly_limit_usd
	if v == nil {
		return
	}
	return *v, true
}

// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
func (m *GroupMutation) ClearMonthlyLimitUsd() {
	m.monthly_limit_usd = nil
	m.addmonthly_limit_usd = nil
	m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{}
}

// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation.
func (m *GroupMutation) MonthlyLimitUsdCleared() bool {
	_, ok := m.clearedFields[group.FieldMonthlyLimitUsd]
	return ok
}

// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field.
func (m *GroupMutation) ResetMonthlyLimitUsd() {
	m.monthly_limit_usd = nil
	m.addmonthly_limit_usd = nil
	delete(m.clearedFields, group.FieldMonthlyLimitUsd)
}

// SetDefaultValidityDays sets the "default_validity_days" field.
func (m *GroupMutation) SetDefaultValidityDays(i int) {
	m.default_validity_days = &i
	m.adddefault_validity_days = nil
}

// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation.
func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) {
	v := m.default_validity_days
	if v == nil {
		return
	}
	return *v, true
}

// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err)
	}
	return oldValue.DefaultValidityDays, nil
}

// AddDefaultValidityDays adds i to the "default_validity_days" field.
func (m *GroupMutation) AddDefaultValidityDays(i int) {
	if m.adddefault_validity_days != nil {
		*m.adddefault_validity_days += i
	} else {
		m.adddefault_validity_days = &i
	}
}

// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation.
func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) {
	v := m.adddefault_validity_days
	if v == nil {
		return
	}
	return *v, true
}

// ResetDefaultValidityDays resets all changes to the "default_validity_days" field.
func (m *GroupMutation) ResetDefaultValidityDays() {
	m.default_validity_days = nil
	m.adddefault_validity_days = nil
}

// SetImagePrice1k sets the "image_price_1k" field.
func (m *GroupMutation) SetImagePrice1k(f float64) {
	m.image_price_1k = &f
	m.addimage_price_1k = nil
}

// ImagePrice1k returns the value of the "image_price_1k" field in the mutation.
func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) {
	v := m.image_price_1k
	if v == nil {
		return
	}
	return *v, true
}

// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldImagePrice1k requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err)
	}
	return oldValue.ImagePrice1k, nil
}

// AddImagePrice1k adds f to the "image_price_1k" field.
func (m *GroupMutation) AddImagePrice1k(f float64) {
	if m.addimage_price_1k != nil {
		*m.addimage_price_1k += f
	} else {
		m.addimage_price_1k = &f
	}
}

// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation.
func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) {
	v := m.addimage_price_1k
	if v == nil {
		return
	}
	return *v, true
}

// ClearImagePrice1k clears the value of the "image_price_1k" field.
func (m *GroupMutation) ClearImagePrice1k() {
	m.image_price_1k = nil
	m.addimage_price_1k = nil
	m.clearedFields[group.FieldImagePrice1k] = struct{}{}
}

// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation.
func (m *GroupMutation) ImagePrice1kCleared() bool {
	_, ok := m.clearedFields[group.FieldImagePrice1k]
	return ok
}

// ResetImagePrice1k resets all changes to the "image_price_1k" field.
func (m *GroupMutation) ResetImagePrice1k() {
	m.image_price_1k = nil
	m.addimage_price_1k = nil
	delete(m.clearedFields, group.FieldImagePrice1k)
}

// SetImagePrice2k sets the "image_price_2k" field.
func (m *GroupMutation) SetImagePrice2k(f float64) {
	m.image_price_2k = &f
	m.addimage_price_2k = nil
}

// ImagePrice2k returns the value of the "image_price_2k" field in the mutation.
func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) {
	v := m.image_price_2k
	if v == nil {
		return
	}
	return *v, true
}

// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldImagePrice2k requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err)
	}
	return oldValue.ImagePrice2k, nil
}

// AddImagePrice2k adds f to the "image_price_2k" field.
func (m *GroupMutation) AddImagePrice2k(f float64) {
	if m.addimage_price_2k != nil {
		*m.addimage_price_2k += f
	} else {
		m.addimage_price_2k = &f
	}
}

// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation.
func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) {
	v := m.addimage_price_2k
	if v == nil {
		return
	}
	return *v, true
}

// ClearImagePrice2k clears the value of the "image_price_2k" field.
func (m *GroupMutation) ClearImagePrice2k() {
	m.image_price_2k = nil
	m.addimage_price_2k = nil
	m.clearedFields[group.FieldImagePrice2k] = struct{}{}
}

// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation.
func (m *GroupMutation) ImagePrice2kCleared() bool {
	_, ok := m.clearedFields[group.FieldImagePrice2k]
	return ok
}

// ResetImagePrice2k resets all changes to the "image_price_2k" field.
func (m *GroupMutation) ResetImagePrice2k() {
	m.image_price_2k = nil
	m.addimage_price_2k = nil
	delete(m.clearedFields, group.FieldImagePrice2k)
}

// SetImagePrice4k sets the "image_price_4k" field.
func (m *GroupMutation) SetImagePrice4k(f float64) {
	m.image_price_4k = &f
	m.addimage_price_4k = nil
}

// ImagePrice4k returns the value of the "image_price_4k" field in the mutation.
func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) {
	v := m.image_price_4k
	if v == nil {
		return
	}
	return *v, true
}

// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldImagePrice4k requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err)
	}
	return oldValue.ImagePrice4k, nil
}

// AddImagePrice4k adds f to the "image_price_4k" field.
func (m *GroupMutation) AddImagePrice4k(f float64) {
	if m.addimage_price_4k != nil {
		*m.addimage_price_4k += f
	} else {
		m.addimage_price_4k = &f
	}
}

// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation.
func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) {
	v := m.addimage_price_4k
	if v == nil {
		return
	}
	return *v, true
}

// ClearImagePrice4k clears the value of the "image_price_4k" field.
func (m *GroupMutation) ClearImagePrice4k() {
	m.image_price_4k = nil
	m.addimage_price_4k = nil
	m.clearedFields[group.FieldImagePrice4k] = struct{}{}
}

// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation.
func (m *GroupMutation) ImagePrice4kCleared() bool {
	_, ok := m.clearedFields[group.FieldImagePrice4k]
	return ok
}

// ResetImagePrice4k resets all changes to the "image_price_4k" field.
func (m *GroupMutation) ResetImagePrice4k() {
	m.image_price_4k = nil
	m.addimage_price_4k = nil
	delete(m.clearedFields, group.FieldImagePrice4k)
}

// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
	m.claude_code_only = &b
}

// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
	v := m.claude_code_only
	if v == nil {
		return
	}
	return *v, true
}

// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
	}
	return oldValue.ClaudeCodeOnly, nil
}

// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
func (m *GroupMutation) ResetClaudeCodeOnly() {
	m.claude_code_only = nil
}

// SetFallbackGroupID sets the "fallback_group_id" field.
func (m *GroupMutation) SetFallbackGroupID(i int64) {
	m.fallback_group_id = &i
	m.addfallback_group_id = nil
}

// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
	v := m.fallback_group_id
	if v == nil {
		return
	}
	return *v, true
}

// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
	}
	return oldValue.FallbackGroupID, nil
}

// AddFallbackGroupID adds i to the "fallback_group_id" field.
func (m *GroupMutation) AddFallbackGroupID(i int64) {
	if m.addfallback_group_id != nil {
		*m.addfallback_group_id += i
	} else {
		m.addfallback_group_id = &i
	}
}

// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
	v := m.addfallback_group_id
	if v == nil {
		return
	}
	return *v, true
}

// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
func (m *GroupMutation) ClearFallbackGroupID() {
	m.fallback_group_id = nil
	m.addfallback_group_id = nil
	m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
}

// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDCleared() bool {
	_, ok := m.clearedFields[group.FieldFallbackGroupID]
	return ok
}

// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
func (m *GroupMutation) ResetFallbackGroupID() {
	m.fallback_group_id = nil
	m.addfallback_group_id = nil
	delete(m.clearedFields, group.FieldFallbackGroupID)
}

// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) {
	m.fallback_group_id_on_invalid_request = &i
	m.addfallback_group_id_on_invalid_request = nil
}

// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation.
func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
	v := m.fallback_group_id_on_invalid_request
	if v == nil {
		return
	}
	return *v, true
}

// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err)
	}
	return oldValue.FallbackGroupIDOnInvalidRequest, nil
}

// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) {
	if m.addfallback_group_id_on_invalid_request != nil {
		*m.addfallback_group_id_on_invalid_request += i
	} else {
		m.addfallback_group_id_on_invalid_request = &i
	}
}

// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation.
func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) {
	v := m.addfallback_group_id_on_invalid_request
	if v == nil {
		return
	}
	return *v, true
}

// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() {
	m.fallback_group_id_on_invalid_request = nil
	m.addfallback_group_id_on_invalid_request = nil
	m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{}
}

// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation.
func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool {
	_, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest]
	return ok
}

// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field.
func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() {
	m.fallback_group_id_on_invalid_request = nil
	m.addfallback_group_id_on_invalid_request = nil
	delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest)
}

// SetModelRouting sets the "model_routing" field.
func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
	m.model_routing = &value
}

// ModelRouting returns the value of the "model_routing" field in the mutation.
func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
	v := m.model_routing
	if v == nil {
		return
	}
	return *v, true
}

// OldModelRouting returns the old "model_routing" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldModelRouting requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
	}
	return oldValue.ModelRouting, nil
}

// ClearModelRouting clears the value of the "model_routing" field.
func (m *GroupMutation) ClearModelRouting() {
	m.model_routing = nil
	m.clearedFields[group.FieldModelRouting] = struct{}{}
}

// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
func (m *GroupMutation) ModelRoutingCleared() bool {
	_, ok := m.clearedFields[group.FieldModelRouting]
	return ok
}

// ResetModelRouting resets all changes to the "model_routing" field.
func (m *GroupMutation) ResetModelRouting() {
	m.model_routing = nil
	delete(m.clearedFields, group.FieldModelRouting)
}

// SetModelRoutingEnabled sets the "model_routing_enabled" field.
func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
	m.model_routing_enabled = &b
}

// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
	v := m.model_routing_enabled
	if v == nil {
		return
	}
	return *v, true
}

// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
	}
	return oldValue.ModelRoutingEnabled, nil
}

// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
func (m *GroupMutation) ResetModelRoutingEnabled() {
	m.model_routing_enabled = nil
}

// SetMcpXMLInject sets the "mcp_xml_inject" field.
func (m *GroupMutation) SetMcpXMLInject(b bool) {
	m.mcp_xml_inject = &b
}

// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation.
func (m *GroupMutation) McpXMLInject() (r bool, exists bool) {
	v := m.mcp_xml_inject
	if v == nil {
		return
	}
	return *v, true
}

// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldMcpXMLInject requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err)
	}
	return oldValue.McpXMLInject, nil
}

// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field.
func (m *GroupMutation) ResetMcpXMLInject() {
	m.mcp_xml_inject = nil
}

// SetSupportedModelScopes sets the "supported_model_scopes" field.
func (m *GroupMutation) SetSupportedModelScopes(s []string) {
	m.supported_model_scopes = &s
	m.appendsupported_model_scopes = nil
}

// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation.
func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) {
	v := m.supported_model_scopes
	if v == nil {
		return
	}
	return *v, true
}

// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err)
	}
	return oldValue.SupportedModelScopes, nil
}

// AppendSupportedModelScopes adds s to the "supported_model_scopes" field.
func (m *GroupMutation) AppendSupportedModelScopes(s []string) {
	m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...)
}

// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation.
func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) {
	if len(m.appendsupported_model_scopes) == 0 {
		return nil, false
	}
	return m.appendsupported_model_scopes, true
}

// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field.
func (m *GroupMutation) ResetSupportedModelScopes() {
	m.supported_model_scopes = nil
	m.appendsupported_model_scopes = nil
}

// SetSortOrder sets the "sort_order" field.
func (m *GroupMutation) SetSortOrder(i int) {
	m.sort_order = &i
	m.addsort_order = nil
}

// SortOrder returns the value of the "sort_order" field in the mutation.
func (m *GroupMutation) SortOrder() (r int, exists bool) {
	v := m.sort_order
	if v == nil {
		return
	}
	return *v, true
}

// OldSortOrder returns the old "sort_order" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldSortOrder requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
	}
	return oldValue.SortOrder, nil
}

// AddSortOrder adds i to the "sort_order" field.
func (m *GroupMutation) AddSortOrder(i int) {
	if m.addsort_order != nil {
		*m.addsort_order += i
	} else {
		m.addsort_order = &i
	}
}

// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
func (m *GroupMutation) AddedSortOrder() (r int, exists bool) {
	v := m.addsort_order
	if v == nil {
		return
	}
	return *v, true
}

// ResetSortOrder resets all changes to the "sort_order" field.
func (m *GroupMutation) ResetSortOrder() {
	m.sort_order = nil
	m.addsort_order = nil
}

// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field.
func (m *GroupMutation) SetAllowMessagesDispatch(b bool) {
	m.allow_messages_dispatch = &b
}

// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation.
func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) {
	v := m.allow_messages_dispatch
	if v == nil {
		return
	}
	return *v, true
}

// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err)
	}
	return oldValue.AllowMessagesDispatch, nil
}

// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field.
func (m *GroupMutation) ResetAllowMessagesDispatch() {
	m.allow_messages_dispatch = nil
}

// SetRequireOauthOnly sets the "require_oauth_only" field.
func (m *GroupMutation) SetRequireOauthOnly(b bool) {
	m.require_oauth_only = &b
}

// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation.
func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) {
	v := m.require_oauth_only
	if v == nil {
		return
	}
	return *v, true
}

// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err)
	}
	return oldValue.RequireOauthOnly, nil
}

// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field.
func (m *GroupMutation) ResetRequireOauthOnly() {
	m.require_oauth_only = nil
}

// SetRequirePrivacySet sets the "require_privacy_set" field.
func (m *GroupMutation) SetRequirePrivacySet(b bool) {
	m.require_privacy_set = &b
}

// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation.
func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) {
	v := m.require_privacy_set
	if v == nil {
		return
	}
	return *v, true
}

// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err)
	}
	return oldValue.RequirePrivacySet, nil
}

// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field.
func (m *GroupMutation) ResetRequirePrivacySet() {
	m.require_privacy_set = nil
}

// SetDefaultMappedModel sets the "default_mapped_model" field.
func (m *GroupMutation) SetDefaultMappedModel(s string) {
	m.default_mapped_model = &s
}

// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation.
func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) {
	v := m.default_mapped_model
	if v == nil {
		return
	}
	return *v, true
}

// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err)
	}
	return oldValue.DefaultMappedModel, nil
}

// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field.
+func (m *GroupMutation) ResetDefaultMappedModel() { + m.default_mapped_model = nil +} + +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { + m.messages_dispatch_model_config = &damdmc +} + +// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. +func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { + v := m.messages_dispatch_model_config + if v == nil { + return + } + return *v, true +} + +// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) + } + return oldValue.MessagesDispatchModelConfig, nil +} + +// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. +func (m *GroupMutation) ResetMessagesDispatchModelConfig() { + m.messages_dispatch_model_config = nil +} + +// SetRpmLimit sets the "rpm_limit" field. +func (m *GroupMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. 
+func (m *GroupMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *GroupMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. +func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *GroupMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + +// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. +func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { + if m.api_keys == nil { + m.api_keys = make(map[int64]struct{}) + } + for i := range ids { + m.api_keys[ids[i]] = struct{}{} + } +} + +// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) ClearAPIKeys() { + m.clearedapi_keys = true +} + +// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. 
+func (m *GroupMutation) APIKeysCleared() bool { + return m.clearedapi_keys +} + +// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. +func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { + if m.removedapi_keys == nil { + m.removedapi_keys = make(map[int64]struct{}) + } + for i := range ids { + delete(m.api_keys, ids[i]) + m.removedapi_keys[ids[i]] = struct{}{} + } +} + +// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { + for id := range m.removedapi_keys { + ids = append(ids, id) + } + return +} + +// APIKeysIDs returns the "api_keys" edge IDs in the mutation. +func (m *GroupMutation) APIKeysIDs() (ids []int64) { + for id := range m.api_keys { + ids = append(ids, id) + } + return +} + +// ResetAPIKeys resets all changes to the "api_keys" edge. +func (m *GroupMutation) ResetAPIKeys() { + m.api_keys = nil + m.clearedapi_keys = false + m.removedapi_keys = nil +} + +// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. +func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { + if m.redeem_codes == nil { + m.redeem_codes = make(map[int64]struct{}) + } + for i := range ids { + m.redeem_codes[ids[i]] = struct{}{} + } +} + +// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) ClearRedeemCodes() { + m.clearedredeem_codes = true +} + +// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. +func (m *GroupMutation) RedeemCodesCleared() bool { + return m.clearedredeem_codes +} + +// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. 
+func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { + if m.removedredeem_codes == nil { + m.removedredeem_codes = make(map[int64]struct{}) + } + for i := range ids { + delete(m.redeem_codes, ids[i]) + m.removedredeem_codes[ids[i]] = struct{}{} + } +} + +// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { + for id := range m.removedredeem_codes { + ids = append(ids, id) + } + return +} + +// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. +func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { + for id := range m.redeem_codes { + ids = append(ids, id) + } + return +} + +// ResetRedeemCodes resets all changes to the "redeem_codes" edge. +func (m *GroupMutation) ResetRedeemCodes() { + m.redeem_codes = nil + m.clearedredeem_codes = false + m.removedredeem_codes = nil +} + +// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. +func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { + if m.subscriptions == nil { + m.subscriptions = make(map[int64]struct{}) + } + for i := range ids { + m.subscriptions[ids[i]] = struct{}{} + } +} + +// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) ClearSubscriptions() { + m.clearedsubscriptions = true +} + +// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. +func (m *GroupMutation) SubscriptionsCleared() bool { + return m.clearedsubscriptions +} + +// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. 
+func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { + if m.removedsubscriptions == nil { + m.removedsubscriptions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.subscriptions, ids[i]) + m.removedsubscriptions[ids[i]] = struct{}{} + } +} + +// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { + for id := range m.removedsubscriptions { + ids = append(ids, id) + } + return +} + +// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. +func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { + for id := range m.subscriptions { + ids = append(ids, id) + } + return +} + +// ResetSubscriptions resets all changes to the "subscriptions" edge. +func (m *GroupMutation) ResetSubscriptions() { + m.subscriptions = nil + m.clearedsubscriptions = false + m.removedsubscriptions = nil +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *GroupMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. 
+func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *GroupMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *GroupMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + +// AddAccountIDs adds the "accounts" edge to the Account entity by ids. +func (m *GroupMutation) AddAccountIDs(ids ...int64) { + if m.accounts == nil { + m.accounts = make(map[int64]struct{}) + } + for i := range ids { + m.accounts[ids[i]] = struct{}{} + } +} + +// ClearAccounts clears the "accounts" edge to the Account entity. +func (m *GroupMutation) ClearAccounts() { + m.clearedaccounts = true +} + +// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. +func (m *GroupMutation) AccountsCleared() bool { + return m.clearedaccounts +} + +// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. +func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { + if m.removedaccounts == nil { + m.removedaccounts = make(map[int64]struct{}) + } + for i := range ids { + delete(m.accounts, ids[i]) + m.removedaccounts[ids[i]] = struct{}{} + } +} + +// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. +func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { + for id := range m.removedaccounts { + ids = append(ids, id) + } + return +} + +// AccountsIDs returns the "accounts" edge IDs in the mutation. +func (m *GroupMutation) AccountsIDs() (ids []int64) { + for id := range m.accounts { + ids = append(ids, id) + } + return +} + +// ResetAccounts resets all changes to the "accounts" edge. 
+func (m *GroupMutation) ResetAccounts() { + m.accounts = nil + m.clearedaccounts = false + m.removedaccounts = nil +} + +// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. +func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { + if m.allowed_users == nil { + m.allowed_users = make(map[int64]struct{}) + } + for i := range ids { + m.allowed_users[ids[i]] = struct{}{} + } +} + +// ClearAllowedUsers clears the "allowed_users" edge to the User entity. +func (m *GroupMutation) ClearAllowedUsers() { + m.clearedallowed_users = true +} + +// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. +func (m *GroupMutation) AllowedUsersCleared() bool { + return m.clearedallowed_users +} + +// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. +func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { + if m.removedallowed_users == nil { + m.removedallowed_users = make(map[int64]struct{}) + } + for i := range ids { + delete(m.allowed_users, ids[i]) + m.removedallowed_users[ids[i]] = struct{}{} + } +} + +// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. +func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { + for id := range m.removedallowed_users { + ids = append(ids, id) + } + return +} + +// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. +func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { + for id := range m.allowed_users { + ids = append(ids, id) + } + return +} + +// ResetAllowedUsers resets all changes to the "allowed_users" edge. +func (m *GroupMutation) ResetAllowedUsers() { + m.allowed_users = nil + m.clearedallowed_users = false + m.removedallowed_users = nil +} + +// Where appends a list predicates to the GroupMutation builder. +func (m *GroupMutation) Where(ps ...predicate.Group) { + m.predicates = append(m.predicates, ps...) 
+} + +// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Group, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *GroupMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *GroupMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (Group). +func (m *GroupMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *GroupMutation) Fields() []string { + fields := make([]string, 0, 31) + if m.created_at != nil { + fields = append(fields, group.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, group.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, group.FieldDeletedAt) + } + if m.name != nil { + fields = append(fields, group.FieldName) + } + if m.description != nil { + fields = append(fields, group.FieldDescription) + } + if m.rate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) + } + if m.is_exclusive != nil { + fields = append(fields, group.FieldIsExclusive) + } + if m.status != nil { + fields = append(fields, group.FieldStatus) + } + if m.platform != nil { + fields = append(fields, group.FieldPlatform) + } + if m.subscription_type != nil { + fields = append(fields, group.FieldSubscriptionType) + } + if m.daily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.weekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.monthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } 
+ if m.default_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.image_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.image_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.image_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.claude_code_only != nil { + fields = append(fields, group.FieldClaudeCodeOnly) + } + if m.fallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.model_routing != nil { + fields = append(fields, group.FieldModelRouting) + } + if m.model_routing_enabled != nil { + fields = append(fields, group.FieldModelRoutingEnabled) + } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } + if m.sort_order != nil { + fields = append(fields, group.FieldSortOrder) + } + if m.allow_messages_dispatch != nil { + fields = append(fields, group.FieldAllowMessagesDispatch) + } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } + if m.default_mapped_model != nil { + fields = append(fields, group.FieldDefaultMappedModel) + } + if m.messages_dispatch_model_config != nil { + fields = append(fields, group.FieldMessagesDispatchModelConfig) + } + if m.rpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *GroupMutation) Field(name string) (ent.Value, bool) { + switch name { + case group.FieldCreatedAt: + return m.CreatedAt() + case group.FieldUpdatedAt: + return m.UpdatedAt() + case group.FieldDeletedAt: + return m.DeletedAt() + case group.FieldName: + return m.Name() + case group.FieldDescription: + return m.Description() + case group.FieldRateMultiplier: + return m.RateMultiplier() + case group.FieldIsExclusive: + return m.IsExclusive() + case group.FieldStatus: + return m.Status() + case group.FieldPlatform: + return m.Platform() + case group.FieldSubscriptionType: + return m.SubscriptionType() + case group.FieldDailyLimitUsd: + return m.DailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.WeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.MonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.DefaultValidityDays() + case group.FieldImagePrice1k: + return m.ImagePrice1k() + case group.FieldImagePrice2k: + return m.ImagePrice2k() + case group.FieldImagePrice4k: + return m.ImagePrice4k() + case group.FieldClaudeCodeOnly: + return m.ClaudeCodeOnly() + case group.FieldFallbackGroupID: + return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() + case group.FieldModelRouting: + return m.ModelRouting() + case group.FieldModelRoutingEnabled: + return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() + case group.FieldSortOrder: + return m.SortOrder() + case group.FieldAllowMessagesDispatch: + return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() + case group.FieldDefaultMappedModel: + return m.DefaultMappedModel() + case group.FieldMessagesDispatchModelConfig: + return m.MessagesDispatchModelConfig() + case group.FieldRpmLimit: + return m.RpmLimit() 
+ } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case group.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case group.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case group.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case group.FieldName: + return m.OldName(ctx) + case group.FieldDescription: + return m.OldDescription(ctx) + case group.FieldRateMultiplier: + return m.OldRateMultiplier(ctx) + case group.FieldIsExclusive: + return m.OldIsExclusive(ctx) + case group.FieldStatus: + return m.OldStatus(ctx) + case group.FieldPlatform: + return m.OldPlatform(ctx) + case group.FieldSubscriptionType: + return m.OldSubscriptionType(ctx) + case group.FieldDailyLimitUsd: + return m.OldDailyLimitUsd(ctx) + case group.FieldWeeklyLimitUsd: + return m.OldWeeklyLimitUsd(ctx) + case group.FieldMonthlyLimitUsd: + return m.OldMonthlyLimitUsd(ctx) + case group.FieldDefaultValidityDays: + return m.OldDefaultValidityDays(ctx) + case group.FieldImagePrice1k: + return m.OldImagePrice1k(ctx) + case group.FieldImagePrice2k: + return m.OldImagePrice2k(ctx) + case group.FieldImagePrice4k: + return m.OldImagePrice4k(ctx) + case group.FieldClaudeCodeOnly: + return m.OldClaudeCodeOnly(ctx) + case group.FieldFallbackGroupID: + return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) + case group.FieldModelRouting: + return m.OldModelRouting(ctx) + case group.FieldModelRoutingEnabled: + return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) + case group.FieldSortOrder: + return m.OldSortOrder(ctx) + case 
group.FieldAllowMessagesDispatch: + return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) + case group.FieldDefaultMappedModel: + return m.OldDefaultMappedModel(ctx) + case group.FieldMessagesDispatchModelConfig: + return m.OldMessagesDispatchModelConfig(ctx) + case group.FieldRpmLimit: + return m.OldRpmLimit(ctx) + } + return nil, fmt.Errorf("unknown Group field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *GroupMutation) SetField(name string, value ent.Value) error { + switch name { + case group.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case group.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case group.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case group.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case group.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + case group.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateMultiplier(v) + return nil + case group.FieldIsExclusive: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsExclusive(v) + return nil + case 
group.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case group.FieldPlatform: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatform(v) + return nil + case group.FieldSubscriptionType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionType(v) + return nil + case group.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyLimitUsd(v) + return nil + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyLimitUsd(v) + return nil + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyLimitUsd(v) + return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice4k(v) + return nil + case group.FieldClaudeCodeOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaudeCodeOnly(v) + return nil + case group.FieldFallbackGroupID: + v, 
ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldModelRouting: + v, ok := value.(map[string][]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRouting(v) + return nil + case group.FieldModelRoutingEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRoutingEnabled(v) + return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) + return nil + case group.FieldAllowMessagesDispatch: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowMessagesDispatch(v) + return nil + case group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil + case group.FieldDefaultMappedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + 
m.SetDefaultMappedModel(v) + return nil + case group.FieldMessagesDispatchModelConfig: + v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMessagesDispatchModelConfig(v) + return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil + } + return fmt.Errorf("unknown Group field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *GroupMutation) AddedFields() []string { + var fields []string + if m.addrate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) + } + if m.adddaily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.addweekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.addmonthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.adddefault_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.addimage_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.addimage_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.addimage_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.addfallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.addsort_order != nil { + fields = append(fields, group.FieldSortOrder) + } + if m.addrpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. 
The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case group.FieldRateMultiplier: + return m.AddedRateMultiplier() + case group.FieldDailyLimitUsd: + return m.AddedDailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.AddedWeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.AddedMonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.AddedDefaultValidityDays() + case group.FieldImagePrice1k: + return m.AddedImagePrice1k() + case group.FieldImagePrice2k: + return m.AddedImagePrice2k() + case group.FieldImagePrice4k: + return m.AddedImagePrice4k() + case group.FieldFallbackGroupID: + return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() + case group.FieldSortOrder: + return m.AddedSortOrder() + case group.FieldRpmLimit: + return m.AddedRpmLimit() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *GroupMutation) AddField(name string, value ent.Value) error { + switch name { + case group.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateMultiplier(v) + return nil + case group.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyLimitUsd(v) + return nil + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyLimitUsd(v) + return nil + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyLimitUsd(v) + return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice4k(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + 
return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) + return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil + } + return fmt.Errorf("unknown Group numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *GroupMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(group.FieldDeletedAt) { + fields = append(fields, group.FieldDeletedAt) + } + if m.FieldCleared(group.FieldDescription) { + fields = append(fields, group.FieldDescription) + } + if m.FieldCleared(group.FieldDailyLimitUsd) { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.FieldCleared(group.FieldWeeklyLimitUsd) { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.FieldCleared(group.FieldMonthlyLimitUsd) { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.FieldCleared(group.FieldImagePrice1k) { + fields = append(fields, group.FieldImagePrice1k) + } + if m.FieldCleared(group.FieldImagePrice2k) { + fields = append(fields, group.FieldImagePrice2k) + } + if m.FieldCleared(group.FieldImagePrice4k) { + fields = append(fields, group.FieldImagePrice4k) + } + if m.FieldCleared(group.FieldFallbackGroupID) { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.FieldCleared(group.FieldModelRouting) { + fields = append(fields, group.FieldModelRouting) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *GroupMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. 
It returns an +// error if the field is not defined in the schema. +func (m *GroupMutation) ClearField(name string) error { + switch name { + case group.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case group.FieldDescription: + m.ClearDescription() + return nil + case group.FieldDailyLimitUsd: + m.ClearDailyLimitUsd() + return nil + case group.FieldWeeklyLimitUsd: + m.ClearWeeklyLimitUsd() + return nil + case group.FieldMonthlyLimitUsd: + m.ClearMonthlyLimitUsd() + return nil + case group.FieldImagePrice1k: + m.ClearImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ClearImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ClearImagePrice4k() + return nil + case group.FieldFallbackGroupID: + m.ClearFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ClearModelRouting() + return nil + } + return fmt.Errorf("unknown Group nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *GroupMutation) ResetField(name string) error { + switch name { + case group.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case group.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case group.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case group.FieldName: + m.ResetName() + return nil + case group.FieldDescription: + m.ResetDescription() + return nil + case group.FieldRateMultiplier: + m.ResetRateMultiplier() + return nil + case group.FieldIsExclusive: + m.ResetIsExclusive() + return nil + case group.FieldStatus: + m.ResetStatus() + return nil + case group.FieldPlatform: + m.ResetPlatform() + return nil + case group.FieldSubscriptionType: + m.ResetSubscriptionType() + return nil + case group.FieldDailyLimitUsd: + m.ResetDailyLimitUsd() + return nil + case group.FieldWeeklyLimitUsd: + m.ResetWeeklyLimitUsd() + return nil + case group.FieldMonthlyLimitUsd: + m.ResetMonthlyLimitUsd() + return nil + case group.FieldDefaultValidityDays: + m.ResetDefaultValidityDays() + return nil + case group.FieldImagePrice1k: + m.ResetImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ResetImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ResetImagePrice4k() + return nil + case group.FieldClaudeCodeOnly: + m.ResetClaudeCodeOnly() + return nil + case group.FieldFallbackGroupID: + m.ResetFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ResetModelRouting() + return nil + case group.FieldModelRoutingEnabled: + m.ResetModelRoutingEnabled() + return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil + case group.FieldSortOrder: + m.ResetSortOrder() + return nil + case group.FieldAllowMessagesDispatch: + m.ResetAllowMessagesDispatch() + return nil + case group.FieldRequireOauthOnly: + 
m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil + case group.FieldDefaultMappedModel: + m.ResetDefaultMappedModel() + return nil + case group.FieldMessagesDispatchModelConfig: + m.ResetMessagesDispatchModelConfig() + return nil + case group.FieldRpmLimit: + m.ResetRpmLimit() + return nil + } + return fmt.Errorf("unknown Group field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *GroupMutation) AddedEdges() []string { + edges := make([]string, 0, 6) + if m.api_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.redeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.subscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.usage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.accounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.allowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.api_keys)) + for id := range m.api_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.redeem_codes)) + for id := range m.redeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.subscriptions)) + for id := range m.subscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.allowed_users)) + for id := range m.allowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *GroupMutation) RemovedEdges() []string { + edges := make([]string, 0, 6) + if m.removedapi_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.removedredeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.removedsubscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.removedaccounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.removedallowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. 
+func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.removedapi_keys)) + for id := range m.removedapi_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.removedredeem_codes)) + for id := range m.removedredeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.removedsubscriptions)) + for id := range m.removedsubscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.removedallowed_users)) + for id := range m.removedallowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *GroupMutation) ClearedEdges() []string { + edges := make([]string, 0, 6) + if m.clearedapi_keys { + edges = append(edges, group.EdgeAPIKeys) + } + if m.clearedredeem_codes { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.clearedsubscriptions { + edges = append(edges, group.EdgeSubscriptions) + } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } + if m.clearedaccounts { + edges = append(edges, group.EdgeAccounts) + } + if m.clearedallowed_users { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. 
+func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeAPIKeys: + return m.clearedapi_keys + case group.EdgeRedeemCodes: + return m.clearedredeem_codes + case group.EdgeSubscriptions: + return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs + case group.EdgeAccounts: + return m.clearedaccounts + case group.EdgeAllowedUsers: + return m.clearedallowed_users + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *GroupMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Group unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeAPIKeys: + m.ResetAPIKeys() + return nil + case group.EdgeRedeemCodes: + m.ResetRedeemCodes() + return nil + case group.EdgeSubscriptions: + m.ResetSubscriptions() + return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil + case group.EdgeAccounts: + m.ResetAccounts() + return nil + case group.EdgeAllowedUsers: + m.ResetAllowedUsers() + return nil + } + return fmt.Errorf("unknown Group edge %s", name) +} + +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. 
+type IdempotencyRecordMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord +} + +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) + +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) + +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ + config: c, + op: op, + typ: TypeIdempotencyRecord, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + var ( + err error + once sync.Once + value *IdempotencyRecord + ) + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdempotencyRecord.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. 
+func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdempotencyRecordMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. 
+func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s +} + +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope + if v == nil { + return + } + return *v, true +} + +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScope is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScope requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScope: %w", err) + } + return oldValue.Scope, nil +} + +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s +} + +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash + if v == nil { + return + } + return *v, true +} + +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + } + return oldValue.IdempotencyKeyHash, nil +} + +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s +} + +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint + if v == nil { + return + } + return *v, true +} + +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. 
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + } + return oldValue.RequestFingerprint, nil +} + +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil +} + +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil +} + +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil +} + +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return + } + return *v, true +} + +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil +} + +// AddResponseStatus adds i to the "response_status" field. 
+func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i + } +} + +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +} + +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok +} + +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +} + +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s +} + +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true +} + +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil +} + +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +} + +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok +} + +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +} + +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s +} + +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true +} + +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +} + +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok +} + +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +} + +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t +} + +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until + if v == nil { + return + } + return *v, true +} + +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + } + return oldValue.LockedUntil, nil +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +} + +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok +} + +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) + } + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) + } + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + } + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + } + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) + } + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) + } + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case 
idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. 
+func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + +// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph. +type IdentityAdoptionDecisionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + adopt_display_name *bool + adopt_avatar *bool + decided_at *time.Time + clearedFields map[string]struct{} + pending_auth_session *int64 + clearedpending_auth_session bool + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*IdentityAdoptionDecision, error) + predicates []predicate.IdentityAdoptionDecision +} + +var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil) + +// identityadoptiondecisionOption allows management of the mutation configuration using functional options. +type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation) + +// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity. 
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation { + m := &IdentityAdoptionDecisionMutation{ + config: c, + op: op, + typ: TypeIdentityAdoptionDecision, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdentityAdoptionDecisionID sets the ID field of the mutation. +func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + var ( + err error + once sync.Once + value *IdentityAdoptionDecision + ) + m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation. +func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdentityAdoptionDecisionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. 
Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. 
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) { + m.pending_auth_session = &i +} + +// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) { + v := m.pending_auth_session + if v == nil { + return + } + return *v, true +} + +// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err) + } + return oldValue.PendingAuthSessionID, nil +} + +// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() { + m.pending_auth_session = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. 
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true +} + +// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() { + m.identity = nil + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool { + _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID] + return ok +} + +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() { + m.identity = nil + delete(m.clearedFields, identityadoptiondecision.FieldIdentityID) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) { + m.adopt_display_name = &b +} + +// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation. 
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) { + v := m.adopt_display_name + if v == nil { + return + } + return *v, true +} + +// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err) + } + return oldValue.AdoptDisplayName, nil +} + +// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() { + m.adopt_display_name = nil +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) { + m.adopt_avatar = &b +} + +// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) { + v := m.adopt_avatar + if v == nil { + return + } + return *v, true +} + +// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAdoptAvatar requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err) + } + return oldValue.AdoptAvatar, nil +} + +// ResetAdoptAvatar resets all changes to the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() { + m.adopt_avatar = nil +} + +// SetDecidedAt sets the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) { + m.decided_at = &t +} + +// DecidedAt returns the value of the "decided_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) { + v := m.decided_at + if v == nil { + return + } + return *v, true +} + +// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDecidedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err) + } + return oldValue.DecidedAt, nil +} + +// ResetDecidedAt resets all changes to the "decided_at" field. 
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() { + m.decided_at = nil +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() { + m.clearedpending_auth_session = true + m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{} +} + +// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool { + return m.clearedpending_auth_session +} + +// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PendingAuthSessionID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) { + if id := m.pending_auth_session; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() { + m.pending_auth_session = nil + m.clearedpending_auth_session = false +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *IdentityAdoptionDecisionMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool { + return m.IdentityIDCleared() || m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. 
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *IdentityAdoptionDecisionMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder. +func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdentityAdoptionDecision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdentityAdoptionDecisionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdentityAdoptionDecision). +func (m *IdentityAdoptionDecisionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *IdentityAdoptionDecisionMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, identityadoptiondecision.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, identityadoptiondecision.FieldUpdatedAt) + } + if m.pending_auth_session != nil { + fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID) + } + if m.identity != nil { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + if m.adopt_display_name != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName) + } + if m.adopt_avatar != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptAvatar) + } + if m.decided_at != nil { + fields = append(fields, identityadoptiondecision.FieldDecidedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.CreatedAt() + case identityadoptiondecision.FieldUpdatedAt: + return m.UpdatedAt() + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.PendingAuthSessionID() + case identityadoptiondecision.FieldIdentityID: + return m.IdentityID() + case identityadoptiondecision.FieldAdoptDisplayName: + return m.AdoptDisplayName() + case identityadoptiondecision.FieldAdoptAvatar: + return m.AdoptAvatar() + case identityadoptiondecision.FieldDecidedAt: + return m.DecidedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case identityadoptiondecision.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.OldPendingAuthSessionID(ctx) + case identityadoptiondecision.FieldIdentityID: + return m.OldIdentityID(ctx) + case identityadoptiondecision.FieldAdoptDisplayName: + return m.OldAdoptDisplayName(ctx) + case identityadoptiondecision.FieldAdoptAvatar: + return m.OldAdoptAvatar(ctx) + case identityadoptiondecision.FieldDecidedAt: + return m.OldDecidedAt(ctx) + } + return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case identityadoptiondecision.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPendingAuthSessionID(v) + return nil + case identityadoptiondecision.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field 
%s", value, name) + } + m.SetAdoptDisplayName(v) + return nil + case identityadoptiondecision.FieldAdoptAvatar: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptAvatar(v) + return nil + case identityadoptiondecision.FieldDecidedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDecidedAt(v) + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) { + switch name { + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(identityadoptiondecision.FieldIdentityID) { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. 
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error { + switch name { + case identityadoptiondecision.FieldIdentityID: + m.ClearIdentityID() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case identityadoptiondecision.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + m.ResetPendingAuthSessionID() + return nil + case identityadoptiondecision.FieldIdentityID: + m.ResetIdentityID() + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + m.ResetAdoptDisplayName() + return nil + case identityadoptiondecision.FieldAdoptAvatar: + m.ResetAdoptAvatar() + return nil + case identityadoptiondecision.FieldDecidedAt: + m.ResetDecidedAt() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.pending_auth_session != nil { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.identity != nil { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + if id := m.pending_auth_session; id != nil { + return []ent.Value{*id} + } + case identityadoptiondecision.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedpending_auth_session { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) + } + if m.clearedidentity { + edges = append(edges, identityadoptiondecision.EdgeIdentity) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + return m.clearedpending_auth_session + case identityadoptiondecision.EdgeIdentity: + return m.clearedidentity + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. 
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ClearPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ResetPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name) +} + +// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. +type PaymentAuditLogMutation struct { + config + op Op + typ string + id *int64 + order_id *string + action *string + detail *string + operator *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentAuditLog, error) + predicates []predicate.PaymentAuditLog +} + +var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) + +// paymentauditlogOption allows management of the mutation configuration using functional options. +type paymentauditlogOption func(*PaymentAuditLogMutation) + +// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. +func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { + m := &PaymentAuditLogMutation{ + config: c, + op: op, + typ: TypePaymentAuditLog, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPaymentAuditLogID sets the ID field of the mutation. 
+func withPaymentAuditLogID(id int64) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + var ( + err error + once sync.Once + value *PaymentAuditLog + ) + m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentAuditLog.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. +func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + m.oldValue = func(context.Context) (*PaymentAuditLog, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentAuditLogMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentAuditLogMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetOrderID sets the "order_id" field. +func (m *PaymentAuditLogMutation) SetOrderID(s string) { + m.order_id = &s +} + +// OrderID returns the value of the "order_id" field in the mutation. +func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { + v := m.order_id + if v == nil { + return + } + return *v, true +} + +// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOrderID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOrderID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOrderID: %w", err) + } + return oldValue.OrderID, nil +} + +// ResetOrderID resets all changes to the "order_id" field. +func (m *PaymentAuditLogMutation) ResetOrderID() { + m.order_id = nil +} + +// SetAction sets the "action" field. +func (m *PaymentAuditLogMutation) SetAction(s string) { + m.action = &s +} + +// Action returns the value of the "action" field in the mutation. 
+func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { + v := m.action + if v == nil { + return + } + return *v, true +} + +// OldAction returns the old "action" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAction is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAction requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAction: %w", err) + } + return oldValue.Action, nil +} + +// ResetAction resets all changes to the "action" field. +func (m *PaymentAuditLogMutation) ResetAction() { + m.action = nil +} + +// SetDetail sets the "detail" field. +func (m *PaymentAuditLogMutation) SetDetail(s string) { + m.detail = &s +} + +// Detail returns the value of the "detail" field in the mutation. +func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { + v := m.detail + if v == nil { + return + } + return *v, true +} + +// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDetail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDetail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDetail: %w", err) + } + return oldValue.Detail, nil +} + +// ResetDetail resets all changes to the "detail" field. +func (m *PaymentAuditLogMutation) ResetDetail() { + m.detail = nil +} + +// SetOperator sets the "operator" field. +func (m *PaymentAuditLogMutation) SetOperator(s string) { + m.operator = &s +} + +// Operator returns the value of the "operator" field in the mutation. +func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { + v := m.operator + if v == nil { + return + } + return *v, true +} + +// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOperator is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOperator requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOperator: %w", err) + } + return oldValue.Operator, nil +} + +// ResetOperator resets all changes to the "operator" field. +func (m *PaymentAuditLogMutation) ResetOperator() { + m.operator = nil +} + +// SetCreatedAt sets the "created_at" field. 
+func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentAuditLogMutation) ResetCreatedAt() { + m.created_at = nil +} + +// Where appends a list predicates to the PaymentAuditLogMutation builder. +func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentAuditLog, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentAuditLogMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. 
+func (m *PaymentAuditLogMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentAuditLog). +func (m *PaymentAuditLogMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentAuditLogMutation) Fields() []string { + fields := make([]string, 0, 5) + if m.order_id != nil { + fields = append(fields, paymentauditlog.FieldOrderID) + } + if m.action != nil { + fields = append(fields, paymentauditlog.FieldAction) + } + if m.detail != nil { + fields = append(fields, paymentauditlog.FieldDetail) + } + if m.operator != nil { + fields = append(fields, paymentauditlog.FieldOperator) + } + if m.created_at != nil { + fields = append(fields, paymentauditlog.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OrderID() + case paymentauditlog.FieldAction: + return m.Action() + case paymentauditlog.FieldDetail: + return m.Detail() + case paymentauditlog.FieldOperator: + return m.Operator() + case paymentauditlog.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OldOrderID(ctx) + case paymentauditlog.FieldAction: + return m.OldAction(ctx) + case paymentauditlog.FieldDetail: + return m.OldDetail(ctx) + case paymentauditlog.FieldOperator: + return m.OldOperator(ctx) + case paymentauditlog.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentauditlog.FieldOrderID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderID(v) + return nil + case paymentauditlog.FieldAction: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAction(v) + return nil + case paymentauditlog.FieldDetail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDetail(v) + return nil + case paymentauditlog.FieldOperator: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOperator(v) + return nil + case paymentauditlog.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown PaymentAuditLog field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *PaymentAuditLogMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentAuditLogMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentAuditLogMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *PaymentAuditLogMutation) ResetField(name string) error { + switch name { + case paymentauditlog.FieldOrderID: + m.ResetOrderID() + return nil + case paymentauditlog.FieldAction: + m.ResetAction() + return nil + case paymentauditlog.FieldDetail: + m.ResetDetail() + return nil + case paymentauditlog.FieldOperator: + m.ResetOperator() + return nil + case paymentauditlog.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown PaymentAuditLog field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentAuditLogMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentAuditLogMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentAuditLogMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. 
+func (m *PaymentAuditLogMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentAuditLogMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog edge %s", name) +} + +// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. +type PaymentOrderMutation struct { + config + op Op + typ string + id *int64 + user_email *string + user_name *string + user_notes *string + amount *float64 + addamount *float64 + pay_amount *float64 + addpay_amount *float64 + fee_rate *float64 + addfee_rate *float64 + recharge_code *string + out_trade_no *string + payment_type *string + payment_trade_no *string + pay_url *string + qr_code *string + qr_code_img *string + order_type *string + plan_id *int64 + addplan_id *int64 + subscription_group_id *int64 + addsubscription_group_id *int64 + subscription_days *int + addsubscription_days *int + provider_instance_id *string + provider_key *string + provider_snapshot *map[string]interface{} + status *string + refund_amount *float64 + addrefund_amount *float64 + refund_reason *string + refund_at *time.Time + force_refund *bool + refund_requested_at *time.Time + refund_request_reason *string + refund_requested_by *string + expires_at *time.Time + paid_at *time.Time + completed_at *time.Time + failed_at *time.Time + failed_reason *string + client_ip *string + src_host *string + src_url *string + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + done bool + oldValue func(context.Context) (*PaymentOrder, error) + predicates []predicate.PaymentOrder +} + +var _ ent.Mutation = (*PaymentOrderMutation)(nil) + +// paymentorderOption allows management of the mutation configuration using functional 
options. +type paymentorderOption func(*PaymentOrderMutation) + +// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. +func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { + m := &PaymentOrderMutation{ + config: c, + op: op, + typ: TypePaymentOrder, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPaymentOrderID sets the ID field of the mutation. +func withPaymentOrderID(id int64) paymentorderOption { + return func(m *PaymentOrderMutation) { + var ( + err error + once sync.Once + value *PaymentOrder + ) + m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentOrder.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPaymentOrder sets the old PaymentOrder of the mutation. +func withPaymentOrder(node *PaymentOrder) paymentorderOption { + return func(m *PaymentOrderMutation) { + m.oldValue = func(context.Context) (*PaymentOrder, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentOrderMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentOrderMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. 
+func (m *PaymentOrderMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetUserID sets the "user_id" field. +func (m *PaymentOrderMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. 
+func (m *PaymentOrderMutation) ResetUserID() { + m.user = nil +} + +// SetUserEmail sets the "user_email" field. +func (m *PaymentOrderMutation) SetUserEmail(s string) { + m.user_email = &s +} + +// UserEmail returns the value of the "user_email" field in the mutation. +func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { + v := m.user_email + if v == nil { + return + } + return *v, true +} + +// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) + } + return oldValue.UserEmail, nil +} + +// ResetUserEmail resets all changes to the "user_email" field. +func (m *PaymentOrderMutation) ResetUserEmail() { + m.user_email = nil +} + +// SetUserName sets the "user_name" field. +func (m *PaymentOrderMutation) SetUserName(s string) { + m.user_name = &s +} + +// UserName returns the value of the "user_name" field in the mutation. +func (m *PaymentOrderMutation) UserName() (r string, exists bool) { + v := m.user_name + if v == nil { + return + } + return *v, true +} + +// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserName: %w", err) + } + return oldValue.UserName, nil +} + +// ResetUserName resets all changes to the "user_name" field. +func (m *PaymentOrderMutation) ResetUserName() { + m.user_name = nil +} + +// SetUserNotes sets the "user_notes" field. +func (m *PaymentOrderMutation) SetUserNotes(s string) { + m.user_notes = &s +} + +// UserNotes returns the value of the "user_notes" field in the mutation. +func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { + v := m.user_notes + if v == nil { + return + } + return *v, true +} + +// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserNotes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) + } + return oldValue.UserNotes, nil +} + +// ClearUserNotes clears the value of the "user_notes" field. 
+func (m *PaymentOrderMutation) ClearUserNotes() { + m.user_notes = nil + m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} +} + +// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. +func (m *PaymentOrderMutation) UserNotesCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldUserNotes] + return ok +} + +// ResetUserNotes resets all changes to the "user_notes" field. +func (m *PaymentOrderMutation) ResetUserNotes() { + m.user_notes = nil + delete(m.clearedFields, paymentorder.FieldUserNotes) +} + +// SetAmount sets the "amount" field. +func (m *PaymentOrderMutation) SetAmount(f float64) { + m.amount = &f + m.addamount = nil +} + +// Amount returns the value of the "amount" field in the mutation. +func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { + v := m.amount + if v == nil { + return + } + return *v, true +} + +// OldAmount returns the old "amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAmount: %w", err) + } + return oldValue.Amount, nil +} + +// AddAmount adds f to the "amount" field. +func (m *PaymentOrderMutation) AddAmount(f float64) { + if m.addamount != nil { + *m.addamount += f + } else { + m.addamount = &f + } +} + +// AddedAmount returns the value that was added to the "amount" field in this mutation. 
+func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { + v := m.addamount + if v == nil { + return + } + return *v, true +} + +// ResetAmount resets all changes to the "amount" field. +func (m *PaymentOrderMutation) ResetAmount() { + m.amount = nil + m.addamount = nil +} + +// SetPayAmount sets the "pay_amount" field. +func (m *PaymentOrderMutation) SetPayAmount(f float64) { + m.pay_amount = &f + m.addpay_amount = nil +} + +// PayAmount returns the value of the "pay_amount" field in the mutation. +func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { + v := m.pay_amount + if v == nil { + return + } + return *v, true +} + +// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) + } + return oldValue.PayAmount, nil +} + +// AddPayAmount adds f to the "pay_amount" field. +func (m *PaymentOrderMutation) AddPayAmount(f float64) { + if m.addpay_amount != nil { + *m.addpay_amount += f + } else { + m.addpay_amount = &f + } +} + +// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { + v := m.addpay_amount + if v == nil { + return + } + return *v, true +} + +// ResetPayAmount resets all changes to the "pay_amount" field. 
+func (m *PaymentOrderMutation) ResetPayAmount() { + m.pay_amount = nil + m.addpay_amount = nil +} + +// SetFeeRate sets the "fee_rate" field. +func (m *PaymentOrderMutation) SetFeeRate(f float64) { + m.fee_rate = &f + m.addfee_rate = nil +} + +// FeeRate returns the value of the "fee_rate" field in the mutation. +func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { + v := m.fee_rate + if v == nil { + return + } + return *v, true +} + +// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeeRate requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) + } + return oldValue.FeeRate, nil +} + +// AddFeeRate adds f to the "fee_rate" field. +func (m *PaymentOrderMutation) AddFeeRate(f float64) { + if m.addfee_rate != nil { + *m.addfee_rate += f + } else { + m.addfee_rate = &f + } +} + +// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. +func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { + v := m.addfee_rate + if v == nil { + return + } + return *v, true +} + +// ResetFeeRate resets all changes to the "fee_rate" field. +func (m *PaymentOrderMutation) ResetFeeRate() { + m.fee_rate = nil + m.addfee_rate = nil +} + +// SetRechargeCode sets the "recharge_code" field. 
+func (m *PaymentOrderMutation) SetRechargeCode(s string) { + m.recharge_code = &s +} + +// RechargeCode returns the value of the "recharge_code" field in the mutation. +func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { + v := m.recharge_code + if v == nil { + return + } + return *v, true +} + +// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRechargeCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) + } + return oldValue.RechargeCode, nil +} + +// ResetRechargeCode resets all changes to the "recharge_code" field. +func (m *PaymentOrderMutation) ResetRechargeCode() { + m.recharge_code = nil +} + +// SetOutTradeNo sets the "out_trade_no" field. +func (m *PaymentOrderMutation) SetOutTradeNo(s string) { + m.out_trade_no = &s +} + +// OutTradeNo returns the value of the "out_trade_no" field in the mutation. +func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { + v := m.out_trade_no + if v == nil { + return + } + return *v, true +} + +// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutTradeNo requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) + } + return oldValue.OutTradeNo, nil +} + +// ResetOutTradeNo resets all changes to the "out_trade_no" field. +func (m *PaymentOrderMutation) ResetOutTradeNo() { + m.out_trade_no = nil +} + +// SetPaymentType sets the "payment_type" field. +func (m *PaymentOrderMutation) SetPaymentType(s string) { + m.payment_type = &s +} + +// PaymentType returns the value of the "payment_type" field in the mutation. +func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { + v := m.payment_type + if v == nil { + return + } + return *v, true +} + +// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) + } + return oldValue.PaymentType, nil +} + +// ResetPaymentType resets all changes to the "payment_type" field. 
+func (m *PaymentOrderMutation) ResetPaymentType() { + m.payment_type = nil +} + +// SetPaymentTradeNo sets the "payment_trade_no" field. +func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { + m.payment_trade_no = &s +} + +// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. +func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { + v := m.payment_trade_no + if v == nil { + return + } + return *v, true +} + +// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) + } + return oldValue.PaymentTradeNo, nil +} + +// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. +func (m *PaymentOrderMutation) ResetPaymentTradeNo() { + m.payment_trade_no = nil +} + +// SetPayURL sets the "pay_url" field. +func (m *PaymentOrderMutation) SetPayURL(s string) { + m.pay_url = &s +} + +// PayURL returns the value of the "pay_url" field in the mutation. +func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { + v := m.pay_url + if v == nil { + return + } + return *v, true +} + +// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayURL: %w", err) + } + return oldValue.PayURL, nil +} + +// ClearPayURL clears the value of the "pay_url" field. +func (m *PaymentOrderMutation) ClearPayURL() { + m.pay_url = nil + m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +} + +// PayURLCleared returns if the "pay_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) PayURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPayURL] + return ok +} + +// ResetPayURL resets all changes to the "pay_url" field. +func (m *PaymentOrderMutation) ResetPayURL() { + m.pay_url = nil + delete(m.clearedFields, paymentorder.FieldPayURL) +} + +// SetQrCode sets the "qr_code" field. +func (m *PaymentOrderMutation) SetQrCode(s string) { + m.qr_code = &s +} + +// QrCode returns the value of the "qr_code" field in the mutation. +func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { + v := m.qr_code + if v == nil { + return + } + return *v, true +} + +// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCode: %w", err) + } + return oldValue.QrCode, nil +} + +// ClearQrCode clears the value of the "qr_code" field. +func (m *PaymentOrderMutation) ClearQrCode() { + m.qr_code = nil + m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +} + +// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCode] + return ok +} + +// ResetQrCode resets all changes to the "qr_code" field. +func (m *PaymentOrderMutation) ResetQrCode() { + m.qr_code = nil + delete(m.clearedFields, paymentorder.FieldQrCode) +} + +// SetQrCodeImg sets the "qr_code_img" field. +func (m *PaymentOrderMutation) SetQrCodeImg(s string) { + m.qr_code_img = &s +} + +// QrCodeImg returns the value of the "qr_code_img" field in the mutation. +func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { + v := m.qr_code_img + if v == nil { + return + } + return *v, true +} + +// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCodeImg requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + } + return oldValue.QrCodeImg, nil +} + +// ClearQrCodeImg clears the value of the "qr_code_img" field. +func (m *PaymentOrderMutation) ClearQrCodeImg() { + m.qr_code_img = nil + m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} +} + +// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeImgCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] + return ok +} + +// ResetQrCodeImg resets all changes to the "qr_code_img" field. +func (m *PaymentOrderMutation) ResetQrCodeImg() { + m.qr_code_img = nil + delete(m.clearedFields, paymentorder.FieldQrCodeImg) +} + +// SetOrderType sets the "order_type" field. +func (m *PaymentOrderMutation) SetOrderType(s string) { + m.order_type = &s +} + +// OrderType returns the value of the "order_type" field in the mutation. +func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { + v := m.order_type + if v == nil { + return + } + return *v, true +} + +// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOrderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOrderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOrderType: %w", err) + } + return oldValue.OrderType, nil +} + +// ResetOrderType resets all changes to the "order_type" field. +func (m *PaymentOrderMutation) ResetOrderType() { + m.order_type = nil +} + +// SetPlanID sets the "plan_id" field. +func (m *PaymentOrderMutation) SetPlanID(i int64) { + m.plan_id = &i + m.addplan_id = nil +} + +// PlanID returns the value of the "plan_id" field in the mutation. +func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { + v := m.plan_id + if v == nil { + return + } + return *v, true +} + +// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlanID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlanID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlanID: %w", err) + } + return oldValue.PlanID, nil +} + +// AddPlanID adds i to the "plan_id" field. 
+func (m *PaymentOrderMutation) AddPlanID(i int64) { + if m.addplan_id != nil { + *m.addplan_id += i + } else { + m.addplan_id = &i + } +} + +// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. +func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { + v := m.addplan_id + if v == nil { + return + } + return *v, true +} + +// ClearPlanID clears the value of the "plan_id" field. +func (m *PaymentOrderMutation) ClearPlanID() { + m.plan_id = nil + m.addplan_id = nil + m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +} + +// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) PlanIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPlanID] + return ok +} + +// ResetPlanID resets all changes to the "plan_id" field. +func (m *PaymentOrderMutation) ResetPlanID() { + m.plan_id = nil + m.addplan_id = nil + delete(m.clearedFields, paymentorder.FieldPlanID) +} + +// SetSubscriptionGroupID sets the "subscription_group_id" field. +func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { + m.subscription_group_id = &i + m.addsubscription_group_id = nil +} + +// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { + v := m.subscription_group_id + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) + } + return oldValue.SubscriptionGroupID, nil +} + +// AddSubscriptionGroupID adds i to the "subscription_group_id" field. +func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { + if m.addsubscription_group_id != nil { + *m.addsubscription_group_id += i + } else { + m.addsubscription_group_id = &i + } +} + +// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { + v := m.addsubscription_group_id + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. +func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} +} + +// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] + return ok +} + +// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. +func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +} + +// SetSubscriptionDays sets the "subscription_days" field. 
+func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { + m.subscription_days = &i + m.addsubscription_days = nil +} + +// SubscriptionDays returns the value of the "subscription_days" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { + v := m.subscription_days + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) + } + return oldValue.SubscriptionDays, nil +} + +// AddSubscriptionDays adds i to the "subscription_days" field. +func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { + if m.addsubscription_days != nil { + *m.addsubscription_days += i + } else { + m.addsubscription_days = &i + } +} + +// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { + v := m.addsubscription_days + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionDays clears the value of the "subscription_days" field. 
+func (m *PaymentOrderMutation) ClearSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} +} + +// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] + return ok +} + +// ResetSubscriptionDays resets all changes to the "subscription_days" field. +func (m *PaymentOrderMutation) ResetSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +} + +// SetProviderInstanceID sets the "provider_instance_id" field. +func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { + m.provider_instance_id = &s +} + +// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. +func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { + v := m.provider_instance_id + if v == nil { + return + } + return *v, true +} + +// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) + } + return oldValue.ProviderInstanceID, nil +} + +// ClearProviderInstanceID clears the value of the "provider_instance_id" field. +func (m *PaymentOrderMutation) ClearProviderInstanceID() { + m.provider_instance_id = nil + m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} +} + +// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] + return ok +} + +// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. +func (m *PaymentOrderMutation) ResetProviderInstanceID() { + m.provider_instance_id = nil + delete(m.clearedFields, paymentorder.FieldProviderInstanceID) +} + +// SetProviderKey sets the "provider_key" field. +func (m *PaymentOrderMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (m *PaymentOrderMutation) ClearProviderKey() { + m.provider_key = nil + m.clearedFields[paymentorder.FieldProviderKey] = struct{}{} +} + +// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderKeyCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderKey] + return ok +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentOrderMutation) ResetProviderKey() { + m.provider_key = nil + delete(m.clearedFields, paymentorder.FieldProviderKey) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) { + m.provider_snapshot = &value +} + +// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation. +func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) { + v := m.provider_snapshot + if v == nil { + return + } + return *v, true +} + +// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSnapshot requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err) + } + return oldValue.ProviderSnapshot, nil +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (m *PaymentOrderMutation) ClearProviderSnapshot() { + m.provider_snapshot = nil + m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{} +} + +// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot] + return ok +} + +// ResetProviderSnapshot resets all changes to the "provider_snapshot" field. +func (m *PaymentOrderMutation) ResetProviderSnapshot() { + m.provider_snapshot = nil + delete(m.clearedFields, paymentorder.FieldProviderSnapshot) +} + +// SetStatus sets the "status" field. +func (m *PaymentOrderMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *PaymentOrderMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *PaymentOrderMutation) ResetStatus() { + m.status = nil +} + +// SetRefundAmount sets the "refund_amount" field. +func (m *PaymentOrderMutation) SetRefundAmount(f float64) { + m.refund_amount = &f + m.addrefund_amount = nil +} + +// RefundAmount returns the value of the "refund_amount" field in the mutation. +func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { + v := m.refund_amount + if v == nil { + return + } + return *v, true +} + +// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + } + return oldValue.RefundAmount, nil +} + +// AddRefundAmount adds f to the "refund_amount" field. 
+func (m *PaymentOrderMutation) AddRefundAmount(f float64) { + if m.addrefund_amount != nil { + *m.addrefund_amount += f + } else { + m.addrefund_amount = &f + } +} + +// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { + v := m.addrefund_amount + if v == nil { + return + } + return *v, true +} + +// ResetRefundAmount resets all changes to the "refund_amount" field. +func (m *PaymentOrderMutation) ResetRefundAmount() { + m.refund_amount = nil + m.addrefund_amount = nil +} + +// SetRefundReason sets the "refund_reason" field. +func (m *PaymentOrderMutation) SetRefundReason(s string) { + m.refund_reason = &s +} + +// RefundReason returns the value of the "refund_reason" field in the mutation. +func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { + v := m.refund_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + } + return oldValue.RefundReason, nil +} + +// ClearRefundReason clears the value of the "refund_reason" field. 
+func (m *PaymentOrderMutation) ClearRefundReason() { + m.refund_reason = nil + m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +} + +// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundReason] + return ok +} + +// ResetRefundReason resets all changes to the "refund_reason" field. +func (m *PaymentOrderMutation) ResetRefundReason() { + m.refund_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundReason) +} + +// SetRefundAt sets the "refund_at" field. +func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { + m.refund_at = &t +} + +// RefundAt returns the value of the "refund_at" field in the mutation. +func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { + v := m.refund_at + if v == nil { + return + } + return *v, true +} + +// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + } + return oldValue.RefundAt, nil +} + +// ClearRefundAt clears the value of the "refund_at" field. +func (m *PaymentOrderMutation) ClearRefundAt() { + m.refund_at = nil + m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} +} + +// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. 
+func (m *PaymentOrderMutation) RefundAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundAt] + return ok +} + +// ResetRefundAt resets all changes to the "refund_at" field. +func (m *PaymentOrderMutation) ResetRefundAt() { + m.refund_at = nil + delete(m.clearedFields, paymentorder.FieldRefundAt) +} + +// SetForceRefund sets the "force_refund" field. +func (m *PaymentOrderMutation) SetForceRefund(b bool) { + m.force_refund = &b +} + +// ForceRefund returns the value of the "force_refund" field in the mutation. +func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { + v := m.force_refund + if v == nil { + return + } + return *v, true +} + +// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldForceRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + } + return oldValue.ForceRefund, nil +} + +// ResetForceRefund resets all changes to the "force_refund" field. +func (m *PaymentOrderMutation) ResetForceRefund() { + m.force_refund = nil +} + +// SetRefundRequestedAt sets the "refund_requested_at" field. +func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { + m.refund_requested_at = &t +} + +// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. 
+func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { + v := m.refund_requested_at + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + } + return oldValue.RefundRequestedAt, nil +} + +// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. +func (m *PaymentOrderMutation) ClearRefundRequestedAt() { + m.refund_requested_at = nil + m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} +} + +// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] + return ok +} + +// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. +func (m *PaymentOrderMutation) ResetRefundRequestedAt() { + m.refund_requested_at = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +} + +// SetRefundRequestReason sets the "refund_request_reason" field. +func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { + m.refund_request_reason = &s +} + +// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. 
+func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { + v := m.refund_request_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + } + return oldValue.RefundRequestReason, nil +} + +// ClearRefundRequestReason clears the value of the "refund_request_reason" field. +func (m *PaymentOrderMutation) ClearRefundRequestReason() { + m.refund_request_reason = nil + m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} +} + +// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] + return ok +} + +// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. +func (m *PaymentOrderMutation) ResetRefundRequestReason() { + m.refund_request_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +} + +// SetRefundRequestedBy sets the "refund_requested_by" field. 
+func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { + m.refund_requested_by = &s +} + +// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { + v := m.refund_requested_by + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + } + return oldValue.RefundRequestedBy, nil +} + +// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. +func (m *PaymentOrderMutation) ClearRefundRequestedBy() { + m.refund_requested_by = nil + m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +} + +// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] + return ok +} + +// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. +func (m *PaymentOrderMutation) ResetRefundRequestedBy() { + m.refund_requested_by = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +} + +// SetExpiresAt sets the "expires_at" field. 
+func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PaymentOrderMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetPaidAt sets the "paid_at" field. +func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { + m.paid_at = &t +} + +// PaidAt returns the value of the "paid_at" field in the mutation. +func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { + v := m.paid_at + if v == nil { + return + } + return *v, true +} + +// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaidAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) + } + return oldValue.PaidAt, nil +} + +// ClearPaidAt clears the value of the "paid_at" field. +func (m *PaymentOrderMutation) ClearPaidAt() { + m.paid_at = nil + m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +} + +// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) PaidAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPaidAt] + return ok +} + +// ResetPaidAt resets all changes to the "paid_at" field. +func (m *PaymentOrderMutation) ResetPaidAt() { + m.paid_at = nil + delete(m.clearedFields, paymentorder.FieldPaidAt) +} + +// SetCompletedAt sets the "completed_at" field. +func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { + m.completed_at = &t +} + +// CompletedAt returns the value of the "completed_at" field in the mutation. +func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { + v := m.completed_at + if v == nil { + return + } + return *v, true +} + +// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) + } + return oldValue.CompletedAt, nil +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *PaymentOrderMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldCompletedAt] + return ok +} + +// ResetCompletedAt resets all changes to the "completed_at" field. +func (m *PaymentOrderMutation) ResetCompletedAt() { + m.completed_at = nil + delete(m.clearedFields, paymentorder.FieldCompletedAt) +} + +// SetFailedAt sets the "failed_at" field. +func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { + m.failed_at = &t +} + +// FailedAt returns the value of the "failed_at" field in the mutation. +func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { + v := m.failed_at + if v == nil { + return + } + return *v, true +} + +// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) + } + return oldValue.FailedAt, nil +} + +// ClearFailedAt clears the value of the "failed_at" field. +func (m *PaymentOrderMutation) ClearFailedAt() { + m.failed_at = nil + m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +} + +// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedAt] + return ok +} + +// ResetFailedAt resets all changes to the "failed_at" field. +func (m *PaymentOrderMutation) ResetFailedAt() { + m.failed_at = nil + delete(m.clearedFields, paymentorder.FieldFailedAt) +} + +// SetFailedReason sets the "failed_reason" field. +func (m *PaymentOrderMutation) SetFailedReason(s string) { + m.failed_reason = &s +} + +// FailedReason returns the value of the "failed_reason" field in the mutation. +func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { + v := m.failed_reason + if v == nil { + return + } + return *v, true +} + +// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) + } + return oldValue.FailedReason, nil +} + +// ClearFailedReason clears the value of the "failed_reason" field. +func (m *PaymentOrderMutation) ClearFailedReason() { + m.failed_reason = nil + m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} +} + +// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedReason] + return ok +} + +// ResetFailedReason resets all changes to the "failed_reason" field. +func (m *PaymentOrderMutation) ResetFailedReason() { + m.failed_reason = nil + delete(m.clearedFields, paymentorder.FieldFailedReason) +} + +// SetClientIP sets the "client_ip" field. +func (m *PaymentOrderMutation) SetClientIP(s string) { + m.client_ip = &s +} + +// ClientIP returns the value of the "client_ip" field in the mutation. +func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { + v := m.client_ip + if v == nil { + return + } + return *v, true +} + +// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClientIP is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClientIP requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClientIP: %w", err) + } + return oldValue.ClientIP, nil +} + +// ResetClientIP resets all changes to the "client_ip" field. +func (m *PaymentOrderMutation) ResetClientIP() { + m.client_ip = nil +} + +// SetSrcHost sets the "src_host" field. +func (m *PaymentOrderMutation) SetSrcHost(s string) { + m.src_host = &s +} + +// SrcHost returns the value of the "src_host" field in the mutation. +func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { + v := m.src_host + if v == nil { + return + } + return *v, true +} + +// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcHost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + } + return oldValue.SrcHost, nil +} + +// ResetSrcHost resets all changes to the "src_host" field. +func (m *PaymentOrderMutation) ResetSrcHost() { + m.src_host = nil +} + +// SetSrcURL sets the "src_url" field. 
+func (m *PaymentOrderMutation) SetSrcURL(s string) { + m.src_url = &s +} + +// SrcURL returns the value of the "src_url" field in the mutation. +func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { + v := m.src_url + if v == nil { + return + } + return *v, true +} + +// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + } + return oldValue.SrcURL, nil +} + +// ClearSrcURL clears the value of the "src_url" field. +func (m *PaymentOrderMutation) ClearSrcURL() { + m.src_url = nil + m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} +} + +// SrcURLCleared returns if the "src_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) SrcURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSrcURL] + return ok +} + +// ResetSrcURL resets all changes to the "src_url" field. +func (m *PaymentOrderMutation) ResetSrcURL() { + m.src_url = nil + delete(m.clearedFields, paymentorder.FieldSrcURL) +} + +// SetCreatedAt sets the "created_at" field. +func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. 
+func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentOrderMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PaymentOrderMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *PaymentOrderMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[paymentorder.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *PaymentOrderMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *PaymentOrderMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *PaymentOrderMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the PaymentOrderMutation builder. +func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. 
+func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentOrder, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentOrderMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PaymentOrderMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentOrder). +func (m *PaymentOrderMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentOrderMutation) Fields() []string { + fields := make([]string, 0, 39) + if m.user != nil { + fields = append(fields, paymentorder.FieldUserID) + } + if m.user_email != nil { + fields = append(fields, paymentorder.FieldUserEmail) + } + if m.user_name != nil { + fields = append(fields, paymentorder.FieldUserName) + } + if m.user_notes != nil { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.amount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.pay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.fee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.recharge_code != nil { + fields = append(fields, paymentorder.FieldRechargeCode) + } + if m.out_trade_no != nil { + fields = append(fields, paymentorder.FieldOutTradeNo) + } + if m.payment_type != nil { + fields = append(fields, paymentorder.FieldPaymentType) + } + if m.payment_trade_no != nil { + fields = append(fields, paymentorder.FieldPaymentTradeNo) + } + if m.pay_url != nil { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.qr_code != nil { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.qr_code_img != nil { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.order_type != 
nil { + fields = append(fields, paymentorder.FieldOrderType) + } + if m.plan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.subscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.subscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.provider_instance_id != nil { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.provider_key != nil { + fields = append(fields, paymentorder.FieldProviderKey) + } + if m.provider_snapshot != nil { + fields = append(fields, paymentorder.FieldProviderSnapshot) + } + if m.status != nil { + fields = append(fields, paymentorder.FieldStatus) + } + if m.refund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + if m.refund_reason != nil { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.refund_at != nil { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.force_refund != nil { + fields = append(fields, paymentorder.FieldForceRefund) + } + if m.refund_requested_at != nil { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.refund_request_reason != nil { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.refund_requested_by != nil { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.expires_at != nil { + fields = append(fields, paymentorder.FieldExpiresAt) + } + if m.paid_at != nil { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.completed_at != nil { + fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.failed_at != nil { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.failed_reason != nil { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.client_ip != nil { + fields = append(fields, paymentorder.FieldClientIP) + } + if m.src_host != nil { + fields = append(fields, paymentorder.FieldSrcHost) + } + if 
m.src_url != nil { + fields = append(fields, paymentorder.FieldSrcURL) + } + if m.created_at != nil { + fields = append(fields, paymentorder.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, paymentorder.FieldUpdatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldUserID: + return m.UserID() + case paymentorder.FieldUserEmail: + return m.UserEmail() + case paymentorder.FieldUserName: + return m.UserName() + case paymentorder.FieldUserNotes: + return m.UserNotes() + case paymentorder.FieldAmount: + return m.Amount() + case paymentorder.FieldPayAmount: + return m.PayAmount() + case paymentorder.FieldFeeRate: + return m.FeeRate() + case paymentorder.FieldRechargeCode: + return m.RechargeCode() + case paymentorder.FieldOutTradeNo: + return m.OutTradeNo() + case paymentorder.FieldPaymentType: + return m.PaymentType() + case paymentorder.FieldPaymentTradeNo: + return m.PaymentTradeNo() + case paymentorder.FieldPayURL: + return m.PayURL() + case paymentorder.FieldQrCode: + return m.QrCode() + case paymentorder.FieldQrCodeImg: + return m.QrCodeImg() + case paymentorder.FieldOrderType: + return m.OrderType() + case paymentorder.FieldPlanID: + return m.PlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.SubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.SubscriptionDays() + case paymentorder.FieldProviderInstanceID: + return m.ProviderInstanceID() + case paymentorder.FieldProviderKey: + return m.ProviderKey() + case paymentorder.FieldProviderSnapshot: + return m.ProviderSnapshot() + case paymentorder.FieldStatus: + return m.Status() + case paymentorder.FieldRefundAmount: + return m.RefundAmount() + case paymentorder.FieldRefundReason: + return 
m.RefundReason() + case paymentorder.FieldRefundAt: + return m.RefundAt() + case paymentorder.FieldForceRefund: + return m.ForceRefund() + case paymentorder.FieldRefundRequestedAt: + return m.RefundRequestedAt() + case paymentorder.FieldRefundRequestReason: + return m.RefundRequestReason() + case paymentorder.FieldRefundRequestedBy: + return m.RefundRequestedBy() + case paymentorder.FieldExpiresAt: + return m.ExpiresAt() + case paymentorder.FieldPaidAt: + return m.PaidAt() + case paymentorder.FieldCompletedAt: + return m.CompletedAt() + case paymentorder.FieldFailedAt: + return m.FailedAt() + case paymentorder.FieldFailedReason: + return m.FailedReason() + case paymentorder.FieldClientIP: + return m.ClientIP() + case paymentorder.FieldSrcHost: + return m.SrcHost() + case paymentorder.FieldSrcURL: + return m.SrcURL() + case paymentorder.FieldCreatedAt: + return m.CreatedAt() + case paymentorder.FieldUpdatedAt: + return m.UpdatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentorder.FieldUserID: + return m.OldUserID(ctx) + case paymentorder.FieldUserEmail: + return m.OldUserEmail(ctx) + case paymentorder.FieldUserName: + return m.OldUserName(ctx) + case paymentorder.FieldUserNotes: + return m.OldUserNotes(ctx) + case paymentorder.FieldAmount: + return m.OldAmount(ctx) + case paymentorder.FieldPayAmount: + return m.OldPayAmount(ctx) + case paymentorder.FieldFeeRate: + return m.OldFeeRate(ctx) + case paymentorder.FieldRechargeCode: + return m.OldRechargeCode(ctx) + case paymentorder.FieldOutTradeNo: + return m.OldOutTradeNo(ctx) + case paymentorder.FieldPaymentType: + return m.OldPaymentType(ctx) + case paymentorder.FieldPaymentTradeNo: + return m.OldPaymentTradeNo(ctx) + case paymentorder.FieldPayURL: + return m.OldPayURL(ctx) + case paymentorder.FieldQrCode: + return m.OldQrCode(ctx) + case paymentorder.FieldQrCodeImg: + return m.OldQrCodeImg(ctx) + case paymentorder.FieldOrderType: + return m.OldOrderType(ctx) + case paymentorder.FieldPlanID: + return m.OldPlanID(ctx) + case paymentorder.FieldSubscriptionGroupID: + return m.OldSubscriptionGroupID(ctx) + case paymentorder.FieldSubscriptionDays: + return m.OldSubscriptionDays(ctx) + case paymentorder.FieldProviderInstanceID: + return m.OldProviderInstanceID(ctx) + case paymentorder.FieldProviderKey: + return m.OldProviderKey(ctx) + case paymentorder.FieldProviderSnapshot: + return m.OldProviderSnapshot(ctx) + case paymentorder.FieldStatus: + return m.OldStatus(ctx) + case paymentorder.FieldRefundAmount: + return m.OldRefundAmount(ctx) + case paymentorder.FieldRefundReason: + return m.OldRefundReason(ctx) + case paymentorder.FieldRefundAt: + return m.OldRefundAt(ctx) + case paymentorder.FieldForceRefund: + return m.OldForceRefund(ctx) + case paymentorder.FieldRefundRequestedAt: + return m.OldRefundRequestedAt(ctx) + case paymentorder.FieldRefundRequestReason: + 
return m.OldRefundRequestReason(ctx) + case paymentorder.FieldRefundRequestedBy: + return m.OldRefundRequestedBy(ctx) + case paymentorder.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case paymentorder.FieldPaidAt: + return m.OldPaidAt(ctx) + case paymentorder.FieldCompletedAt: + return m.OldCompletedAt(ctx) + case paymentorder.FieldFailedAt: + return m.OldFailedAt(ctx) + case paymentorder.FieldFailedReason: + return m.OldFailedReason(ctx) + case paymentorder.FieldClientIP: + return m.OldClientIP(ctx) + case paymentorder.FieldSrcHost: + return m.OldSrcHost(ctx) + case paymentorder.FieldSrcURL: + return m.OldSrcURL(ctx) + case paymentorder.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case paymentorder.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + } + return nil, fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case paymentorder.FieldUserEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserEmail(v) + return nil + case paymentorder.FieldUserName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserName(v) + return nil + case paymentorder.FieldUserNotes: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserNotes(v) + return nil + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeeRate(v) + return nil + case paymentorder.FieldRechargeCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRechargeCode(v) + return nil + case paymentorder.FieldOutTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutTradeNo(v) + return nil + case paymentorder.FieldPaymentType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentType(v) + return nil + case paymentorder.FieldPaymentTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field 
%s", value, name) + } + m.SetPaymentTradeNo(v) + return nil + case paymentorder.FieldPayURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayURL(v) + return nil + case paymentorder.FieldQrCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCode(v) + return nil + case paymentorder.FieldQrCodeImg: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCodeImg(v) + return nil + case paymentorder.FieldOrderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderType(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionDays(v) + return nil + case paymentorder.FieldProviderInstanceID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderInstanceID(v) + return nil + case paymentorder.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case paymentorder.FieldProviderSnapshot: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSnapshot(v) + return nil + case paymentorder.FieldStatus: + v, ok := value.(string) + if !ok { + return 
fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundAmount(v) + return nil + case paymentorder.FieldRefundReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundReason(v) + return nil + case paymentorder.FieldRefundAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundAt(v) + return nil + case paymentorder.FieldForceRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetForceRefund(v) + return nil + case paymentorder.FieldRefundRequestedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedAt(v) + return nil + case paymentorder.FieldRefundRequestReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestReason(v) + return nil + case paymentorder.FieldRefundRequestedBy: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedBy(v) + return nil + case paymentorder.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case paymentorder.FieldPaidAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaidAt(v) + return nil + case paymentorder.FieldCompletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletedAt(v) + return nil + case paymentorder.FieldFailedAt: + v, ok := 
value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedAt(v) + return nil + case paymentorder.FieldFailedReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedReason(v) + return nil + case paymentorder.FieldClientIP: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClientIP(v) + return nil + case paymentorder.FieldSrcHost: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcHost(v) + return nil + case paymentorder.FieldSrcURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcURL(v) + return nil + case paymentorder.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case paymentorder.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *PaymentOrderMutation) AddedFields() []string { + var fields []string + if m.addamount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.addpay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.addfee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.addplan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.addsubscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.addsubscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.addrefund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldAmount: + return m.AddedAmount() + case paymentorder.FieldPayAmount: + return m.AddedPayAmount() + case paymentorder.FieldFeeRate: + return m.AddedFeeRate() + case paymentorder.FieldPlanID: + return m.AddedPlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.AddedSubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.AddedSubscriptionDays() + case paymentorder.FieldRefundAmount: + return m.AddedRefundAmount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFeeRate(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionDays(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRefundAmount(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *PaymentOrderMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(paymentorder.FieldUserNotes) { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.FieldCleared(paymentorder.FieldPayURL) { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.FieldCleared(paymentorder.FieldQrCode) { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.FieldCleared(paymentorder.FieldQrCodeImg) { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.FieldCleared(paymentorder.FieldPlanID) { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionDays) { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.FieldCleared(paymentorder.FieldProviderInstanceID) { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.FieldCleared(paymentorder.FieldProviderKey) { + fields = append(fields, paymentorder.FieldProviderKey) + } + if m.FieldCleared(paymentorder.FieldProviderSnapshot) { + fields = append(fields, paymentorder.FieldProviderSnapshot) + } + if m.FieldCleared(paymentorder.FieldRefundReason) { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.FieldCleared(paymentorder.FieldRefundAt) { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestReason) { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.FieldCleared(paymentorder.FieldPaidAt) { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.FieldCleared(paymentorder.FieldCompletedAt) { + 
fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.FieldCleared(paymentorder.FieldFailedAt) { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.FieldCleared(paymentorder.FieldFailedReason) { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.FieldCleared(paymentorder.FieldSrcURL) { + fields = append(fields, paymentorder.FieldSrcURL) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentOrderMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ClearField(name string) error { + switch name { + case paymentorder.FieldUserNotes: + m.ClearUserNotes() + return nil + case paymentorder.FieldPayURL: + m.ClearPayURL() + return nil + case paymentorder.FieldQrCode: + m.ClearQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ClearQrCodeImg() + return nil + case paymentorder.FieldPlanID: + m.ClearPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ClearSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ClearSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + m.ClearProviderInstanceID() + return nil + case paymentorder.FieldProviderKey: + m.ClearProviderKey() + return nil + case paymentorder.FieldProviderSnapshot: + m.ClearProviderSnapshot() + return nil + case paymentorder.FieldRefundReason: + m.ClearRefundReason() + return nil + case paymentorder.FieldRefundAt: + m.ClearRefundAt() + return nil + case paymentorder.FieldRefundRequestedAt: + m.ClearRefundRequestedAt() + return nil + case paymentorder.FieldRefundRequestReason: + m.ClearRefundRequestReason() + return nil + case paymentorder.FieldRefundRequestedBy: + m.ClearRefundRequestedBy() + 
return nil + case paymentorder.FieldPaidAt: + m.ClearPaidAt() + return nil + case paymentorder.FieldCompletedAt: + m.ClearCompletedAt() + return nil + case paymentorder.FieldFailedAt: + m.ClearFailedAt() + return nil + case paymentorder.FieldFailedReason: + m.ClearFailedReason() + return nil + case paymentorder.FieldSrcURL: + m.ClearSrcURL() + return nil + } + return fmt.Errorf("unknown PaymentOrder nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ResetField(name string) error { + switch name { + case paymentorder.FieldUserID: + m.ResetUserID() + return nil + case paymentorder.FieldUserEmail: + m.ResetUserEmail() + return nil + case paymentorder.FieldUserName: + m.ResetUserName() + return nil + case paymentorder.FieldUserNotes: + m.ResetUserNotes() + return nil + case paymentorder.FieldAmount: + m.ResetAmount() + return nil + case paymentorder.FieldPayAmount: + m.ResetPayAmount() + return nil + case paymentorder.FieldFeeRate: + m.ResetFeeRate() + return nil + case paymentorder.FieldRechargeCode: + m.ResetRechargeCode() + return nil + case paymentorder.FieldOutTradeNo: + m.ResetOutTradeNo() + return nil + case paymentorder.FieldPaymentType: + m.ResetPaymentType() + return nil + case paymentorder.FieldPaymentTradeNo: + m.ResetPaymentTradeNo() + return nil + case paymentorder.FieldPayURL: + m.ResetPayURL() + return nil + case paymentorder.FieldQrCode: + m.ResetQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ResetQrCodeImg() + return nil + case paymentorder.FieldOrderType: + m.ResetOrderType() + return nil + case paymentorder.FieldPlanID: + m.ResetPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ResetSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ResetSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + 
m.ResetProviderInstanceID() + return nil + case paymentorder.FieldProviderKey: + m.ResetProviderKey() + return nil + case paymentorder.FieldProviderSnapshot: + m.ResetProviderSnapshot() + return nil + case paymentorder.FieldStatus: + m.ResetStatus() + return nil + case paymentorder.FieldRefundAmount: + m.ResetRefundAmount() + return nil + case paymentorder.FieldRefundReason: + m.ResetRefundReason() + return nil + case paymentorder.FieldRefundAt: + m.ResetRefundAt() + return nil + case paymentorder.FieldForceRefund: + m.ResetForceRefund() + return nil + case paymentorder.FieldRefundRequestedAt: + m.ResetRefundRequestedAt() + return nil + case paymentorder.FieldRefundRequestReason: + m.ResetRefundRequestReason() + return nil + case paymentorder.FieldRefundRequestedBy: + m.ResetRefundRequestedBy() + return nil + case paymentorder.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case paymentorder.FieldPaidAt: + m.ResetPaidAt() + return nil + case paymentorder.FieldCompletedAt: + m.ResetCompletedAt() + return nil + case paymentorder.FieldFailedAt: + m.ResetFailedAt() + return nil + case paymentorder.FieldFailedReason: + m.ResetFailedReason() + return nil + case paymentorder.FieldClientIP: + m.ResetClientIP() + return nil + case paymentorder.FieldSrcHost: + m.ResetSrcHost() + return nil + case paymentorder.FieldSrcURL: + m.ResetSrcURL() + return nil + case paymentorder.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case paymentorder.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentOrderMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. 
+func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { + switch name { + case paymentorder.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentOrderMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentOrderMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentOrderMutation) EdgeCleared(name string) bool { + switch name { + case paymentorder.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentOrderMutation) ClearEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown PaymentOrder unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentOrderMutation) ResetEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ResetUser() + return nil + } + return fmt.Errorf("unknown PaymentOrder edge %s", name) +} + +// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. 
+type PaymentProviderInstanceMutation struct { + config + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance +} + +var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) + +// paymentproviderinstanceOption allows management of the mutation configuration using functional options. +type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) + +// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. +func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { + m := &PaymentProviderInstanceMutation{ + config: c, + op: op, + typ: TypePaymentProviderInstance, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPaymentProviderInstanceID sets the ID field of the mutation. +func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + var ( + err error + once sync.Once + value *PaymentProviderInstance + ) + m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. 
+func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentProviderInstanceMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. 
+func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetProviderKey sets the "provider_key" field. +func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key if v == nil { return } return *v, true } -// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { +func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClientIP is only allowed on UpdateOne operations") + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClientIP requires an ID field in the mutation") + return v, errors.New("OldProviderKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldClientIP: %w", err) + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) } - return oldValue.ClientIP, nil + return oldValue.ProviderKey, nil } -// ResetClientIP resets all changes to the "client_ip" field. -func (m *PaymentOrderMutation) ResetClientIP() { - m.client_ip = nil +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentProviderInstanceMutation) ResetProviderKey() { + m.provider_key = nil } -// SetSrcHost sets the "src_host" field. -func (m *PaymentOrderMutation) SetSrcHost(s string) { - m.src_host = &s +// SetName sets the "name" field. +func (m *PaymentProviderInstanceMutation) SetName(s string) { + m.name = &s } -// SrcHost returns the value of the "src_host" field in the mutation. -func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { - v := m.src_host +// Name returns the value of the "name" field in the mutation. +func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { + v := m.name if v == nil { return } return *v, true } -// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the PaymentProviderInstance entity. 
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { +func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + return v, errors.New("OldName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcHost requires an ID field in the mutation") + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.SrcHost, nil + return oldValue.Name, nil } -// ResetSrcHost resets all changes to the "src_host" field. -func (m *PaymentOrderMutation) ResetSrcHost() { - m.src_host = nil +// ResetName resets all changes to the "name" field. +func (m *PaymentProviderInstanceMutation) ResetName() { + m.name = nil } -// SetSrcURL sets the "src_url" field. -func (m *PaymentOrderMutation) SetSrcURL(s string) { - m.src_url = &s +// SetConfig sets the "config" field. +func (m *PaymentProviderInstanceMutation) SetConfig(s string) { + m._config = &s } -// SrcURL returns the value of the "src_url" field in the mutation. -func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { - v := m.src_url +// Config returns the value of the "config" field in the mutation. +func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { + v := m._config if v == nil { return } return *v, true } -// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. 
-// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { +func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + return v, errors.New("OldConfig is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcURL requires an ID field in the mutation") + return v, errors.New("OldConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConfig: %w", err) + } + return oldValue.Config, nil +} + +// ResetConfig resets all changes to the "config" field. +func (m *PaymentProviderInstanceMutation) ResetConfig() { + m._config = nil +} + +// SetSupportedTypes sets the "supported_types" field. +func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { + m.supported_types = &s +} + +// SupportedTypes returns the value of the "supported_types" field in the mutation. +func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { + v := m.supported_types + if v == nil { + return + } + return *v, true +} + +// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + } + return oldValue.SupportedTypes, nil +} + +// ResetSupportedTypes resets all changes to the "supported_types" field. +func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { + m.supported_types = nil +} + +// SetEnabled sets the "enabled" field. +func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. 
+func (m *PaymentProviderInstanceMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPaymentMode sets the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { + m.payment_mode = &s +} + +// PaymentMode returns the value of the "payment_mode" field in the mutation. +func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { + v := m.payment_mode + if v == nil { + return + } + return *v, true +} + +// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + } + return oldValue.PaymentMode, nil +} + +// ResetPaymentMode resets all changes to the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { + m.payment_mode = nil +} + +// SetSortOrder sets the "sort_order" field. +func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil +} + +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return + } + return *v, true +} + +// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. 
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil +} + +// AddSortOrder adds i to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i + } +} + +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { + return + } + return *v, true +} + +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil +} + +// SetLimits sets the "limits" field. +func (m *PaymentProviderInstanceMutation) SetLimits(s string) { + m.limits = &s +} + +// Limits returns the value of the "limits" field in the mutation. +func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { + v := m.limits + if v == nil { + return + } + return *v, true +} + +// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLimits is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLimits requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLimits: %w", err) + } + return oldValue.Limits, nil +} + +// ResetLimits resets all changes to the "limits" field. +func (m *PaymentProviderInstanceMutation) ResetLimits() { + m.limits = nil +} + +// SetRefundEnabled sets the "refund_enabled" field. +func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { + m.refund_enabled = &b +} + +// RefundEnabled returns the value of the "refund_enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { + v := m.refund_enabled + if v == nil { + return + } + return *v, true +} + +// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) + } + return oldValue.RefundEnabled, nil +} + +// ResetRefundEnabled resets all changes to the "refund_enabled" field. 
+func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { + m.refund_enabled = nil +} + +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return + } + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) } - return oldValue.SrcURL, nil -} - -// ClearSrcURL clears the value of the "src_url" field. -func (m *PaymentOrderMutation) ClearSrcURL() { - m.src_url = nil - m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} -} - -// SrcURLCleared returns if the "src_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) SrcURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSrcURL] - return ok + return oldValue.AllowUserRefund, nil } -// ResetSrcURL resets all changes to the "src_url" field. 
-func (m *PaymentOrderMutation) ResetSrcURL() { - m.src_url = nil - delete(m.clearedFields, paymentorder.FieldSrcURL) +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil } // SetCreatedAt sets the "created_at" field. -func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { +func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { +func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -14539,10 +23696,10 @@ func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -14557,17 +23714,17 @@ func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, e } // ResetCreatedAt resets all changes to the "created_at" field. 
-func (m *PaymentOrderMutation) ResetCreatedAt() { +func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { +func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -14575,10 +23732,10 @@ func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -14593,46 +23750,19 @@ func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, e } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentOrderMutation) ResetUpdatedAt() { +func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { m.updated_at = nil } -// ClearUser clears the "user" edge to the User entity. 
-func (m *PaymentOrderMutation) ClearUser() { - m.cleareduser = true - m.clearedFields[paymentorder.FieldUserID] = struct{}{} -} - -// UserCleared reports if the "user" edge to the User entity was cleared. -func (m *PaymentOrderMutation) UserCleared() bool { - return m.cleareduser -} - -// UserIDs returns the "user" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// UserID instead. It exists only for internal usage by the builders. -func (m *PaymentOrderMutation) UserIDs() (ids []int64) { - if id := m.user; id != nil { - ids = append(ids, *id) - } - return -} - -// ResetUser resets all changes to the "user" edge. -func (m *PaymentOrderMutation) ResetUser() { - m.user = nil - m.cleareduser = false -} - -// Where appends a list predicates to the PaymentOrderMutation builder. -func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { +// Where appends a list predicates to the PaymentProviderInstanceMutation builder. +func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, +// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentOrder, len(ps)) +func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentProviderInstance, len(ps)) for i := range ps { p[i] = ps[i] } @@ -14640,135 +23770,60 @@ func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. 
-func (m *PaymentOrderMutation) Op() Op { +func (m *PaymentProviderInstanceMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *PaymentOrderMutation) SetOp(op Op) { +func (m *PaymentProviderInstanceMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (PaymentOrder). -func (m *PaymentOrderMutation) Type() string { +// Type returns the node type of this mutation (PaymentProviderInstance). +func (m *PaymentProviderInstanceMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *PaymentOrderMutation) Fields() []string { - fields := make([]string, 0, 37) - if m.user != nil { - fields = append(fields, paymentorder.FieldUserID) - } - if m.user_email != nil { - fields = append(fields, paymentorder.FieldUserEmail) - } - if m.user_name != nil { - fields = append(fields, paymentorder.FieldUserName) - } - if m.user_notes != nil { - fields = append(fields, paymentorder.FieldUserNotes) - } - if m.amount != nil { - fields = append(fields, paymentorder.FieldAmount) - } - if m.pay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) - } - if m.fee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) - } - if m.recharge_code != nil { - fields = append(fields, paymentorder.FieldRechargeCode) - } - if m.out_trade_no != nil { - fields = append(fields, paymentorder.FieldOutTradeNo) - } - if m.payment_type != nil { - fields = append(fields, paymentorder.FieldPaymentType) - } - if m.payment_trade_no != nil { - fields = append(fields, paymentorder.FieldPaymentTradeNo) - } - if m.pay_url != nil { - fields = append(fields, paymentorder.FieldPayURL) - } - if m.qr_code != nil { - fields = append(fields, paymentorder.FieldQrCode) - } - if m.qr_code_img != nil { - fields = append(fields, paymentorder.FieldQrCodeImg) - } - if 
m.order_type != nil { - fields = append(fields, paymentorder.FieldOrderType) - } - if m.plan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) - } - if m.subscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) - } - if m.subscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) - } - if m.provider_instance_id != nil { - fields = append(fields, paymentorder.FieldProviderInstanceID) - } - if m.status != nil { - fields = append(fields, paymentorder.FieldStatus) - } - if m.refund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) - } - if m.refund_reason != nil { - fields = append(fields, paymentorder.FieldRefundReason) - } - if m.refund_at != nil { - fields = append(fields, paymentorder.FieldRefundAt) - } - if m.force_refund != nil { - fields = append(fields, paymentorder.FieldForceRefund) - } - if m.refund_requested_at != nil { - fields = append(fields, paymentorder.FieldRefundRequestedAt) - } - if m.refund_request_reason != nil { - fields = append(fields, paymentorder.FieldRefundRequestReason) +func (m *PaymentProviderInstanceMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.provider_key != nil { + fields = append(fields, paymentproviderinstance.FieldProviderKey) } - if m.refund_requested_by != nil { - fields = append(fields, paymentorder.FieldRefundRequestedBy) + if m.name != nil { + fields = append(fields, paymentproviderinstance.FieldName) } - if m.expires_at != nil { - fields = append(fields, paymentorder.FieldExpiresAt) + if m._config != nil { + fields = append(fields, paymentproviderinstance.FieldConfig) } - if m.paid_at != nil { - fields = append(fields, paymentorder.FieldPaidAt) + if m.supported_types != nil { + fields = append(fields, paymentproviderinstance.FieldSupportedTypes) } - if m.completed_at != nil { - fields = append(fields, paymentorder.FieldCompletedAt) + if m.enabled != nil { + fields = append(fields, 
paymentproviderinstance.FieldEnabled) } - if m.failed_at != nil { - fields = append(fields, paymentorder.FieldFailedAt) + if m.payment_mode != nil { + fields = append(fields, paymentproviderinstance.FieldPaymentMode) } - if m.failed_reason != nil { - fields = append(fields, paymentorder.FieldFailedReason) + if m.sort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) } - if m.client_ip != nil { - fields = append(fields, paymentorder.FieldClientIP) + if m.limits != nil { + fields = append(fields, paymentproviderinstance.FieldLimits) } - if m.src_host != nil { - fields = append(fields, paymentorder.FieldSrcHost) + if m.refund_enabled != nil { + fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } - if m.src_url != nil { - fields = append(fields, paymentorder.FieldSrcURL) + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) } if m.created_at != nil { - fields = append(fields, paymentorder.FieldCreatedAt) + fields = append(fields, paymentproviderinstance.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, paymentorder.FieldUpdatedAt) + fields = append(fields, paymentproviderinstance.FieldUpdatedAt) } return fields } @@ -14776,81 +23831,31 @@ func (m *PaymentOrderMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. 
-func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { +func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { switch name { - case paymentorder.FieldUserID: - return m.UserID() - case paymentorder.FieldUserEmail: - return m.UserEmail() - case paymentorder.FieldUserName: - return m.UserName() - case paymentorder.FieldUserNotes: - return m.UserNotes() - case paymentorder.FieldAmount: - return m.Amount() - case paymentorder.FieldPayAmount: - return m.PayAmount() - case paymentorder.FieldFeeRate: - return m.FeeRate() - case paymentorder.FieldRechargeCode: - return m.RechargeCode() - case paymentorder.FieldOutTradeNo: - return m.OutTradeNo() - case paymentorder.FieldPaymentType: - return m.PaymentType() - case paymentorder.FieldPaymentTradeNo: - return m.PaymentTradeNo() - case paymentorder.FieldPayURL: - return m.PayURL() - case paymentorder.FieldQrCode: - return m.QrCode() - case paymentorder.FieldQrCodeImg: - return m.QrCodeImg() - case paymentorder.FieldOrderType: - return m.OrderType() - case paymentorder.FieldPlanID: - return m.PlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.SubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.SubscriptionDays() - case paymentorder.FieldProviderInstanceID: - return m.ProviderInstanceID() - case paymentorder.FieldStatus: - return m.Status() - case paymentorder.FieldRefundAmount: - return m.RefundAmount() - case paymentorder.FieldRefundReason: - return m.RefundReason() - case paymentorder.FieldRefundAt: - return m.RefundAt() - case paymentorder.FieldForceRefund: - return m.ForceRefund() - case paymentorder.FieldRefundRequestedAt: - return m.RefundRequestedAt() - case paymentorder.FieldRefundRequestReason: - return m.RefundRequestReason() - case paymentorder.FieldRefundRequestedBy: - return m.RefundRequestedBy() - case paymentorder.FieldExpiresAt: - return m.ExpiresAt() - case paymentorder.FieldPaidAt: - return m.PaidAt() - case 
paymentorder.FieldCompletedAt: - return m.CompletedAt() - case paymentorder.FieldFailedAt: - return m.FailedAt() - case paymentorder.FieldFailedReason: - return m.FailedReason() - case paymentorder.FieldClientIP: - return m.ClientIP() - case paymentorder.FieldSrcHost: - return m.SrcHost() - case paymentorder.FieldSrcURL: - return m.SrcURL() - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldProviderKey: + return m.ProviderKey() + case paymentproviderinstance.FieldName: + return m.Name() + case paymentproviderinstance.FieldConfig: + return m.Config() + case paymentproviderinstance.FieldSupportedTypes: + return m.SupportedTypes() + case paymentproviderinstance.FieldEnabled: + return m.Enabled() + case paymentproviderinstance.FieldPaymentMode: + return m.PaymentMode() + case paymentproviderinstance.FieldSortOrder: + return m.SortOrder() + case paymentproviderinstance.FieldLimits: + return m.Limits() + case paymentproviderinstance.FieldRefundEnabled: + return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() + case paymentproviderinstance.FieldCreatedAt: return m.CreatedAt() - case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: return m.UpdatedAt() } return nil, false @@ -14859,344 +23864,119 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. 
-func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentorder.FieldUserID: - return m.OldUserID(ctx) - case paymentorder.FieldUserEmail: - return m.OldUserEmail(ctx) - case paymentorder.FieldUserName: - return m.OldUserName(ctx) - case paymentorder.FieldUserNotes: - return m.OldUserNotes(ctx) - case paymentorder.FieldAmount: - return m.OldAmount(ctx) - case paymentorder.FieldPayAmount: - return m.OldPayAmount(ctx) - case paymentorder.FieldFeeRate: - return m.OldFeeRate(ctx) - case paymentorder.FieldRechargeCode: - return m.OldRechargeCode(ctx) - case paymentorder.FieldOutTradeNo: - return m.OldOutTradeNo(ctx) - case paymentorder.FieldPaymentType: - return m.OldPaymentType(ctx) - case paymentorder.FieldPaymentTradeNo: - return m.OldPaymentTradeNo(ctx) - case paymentorder.FieldPayURL: - return m.OldPayURL(ctx) - case paymentorder.FieldQrCode: - return m.OldQrCode(ctx) - case paymentorder.FieldQrCodeImg: - return m.OldQrCodeImg(ctx) - case paymentorder.FieldOrderType: - return m.OldOrderType(ctx) - case paymentorder.FieldPlanID: - return m.OldPlanID(ctx) - case paymentorder.FieldSubscriptionGroupID: - return m.OldSubscriptionGroupID(ctx) - case paymentorder.FieldSubscriptionDays: - return m.OldSubscriptionDays(ctx) - case paymentorder.FieldProviderInstanceID: - return m.OldProviderInstanceID(ctx) - case paymentorder.FieldStatus: - return m.OldStatus(ctx) - case paymentorder.FieldRefundAmount: - return m.OldRefundAmount(ctx) - case paymentorder.FieldRefundReason: - return m.OldRefundReason(ctx) - case paymentorder.FieldRefundAt: - return m.OldRefundAt(ctx) - case paymentorder.FieldForceRefund: - return m.OldForceRefund(ctx) - case paymentorder.FieldRefundRequestedAt: - return m.OldRefundRequestedAt(ctx) - case paymentorder.FieldRefundRequestReason: - return m.OldRefundRequestReason(ctx) - case 
paymentorder.FieldRefundRequestedBy: - return m.OldRefundRequestedBy(ctx) - case paymentorder.FieldExpiresAt: - return m.OldExpiresAt(ctx) - case paymentorder.FieldPaidAt: - return m.OldPaidAt(ctx) - case paymentorder.FieldCompletedAt: - return m.OldCompletedAt(ctx) - case paymentorder.FieldFailedAt: - return m.OldFailedAt(ctx) - case paymentorder.FieldFailedReason: - return m.OldFailedReason(ctx) - case paymentorder.FieldClientIP: - return m.OldClientIP(ctx) - case paymentorder.FieldSrcHost: - return m.OldSrcHost(ctx) - case paymentorder.FieldSrcURL: - return m.OldSrcURL(ctx) - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldProviderKey: + return m.OldProviderKey(ctx) + case paymentproviderinstance.FieldName: + return m.OldName(ctx) + case paymentproviderinstance.FieldConfig: + return m.OldConfig(ctx) + case paymentproviderinstance.FieldSupportedTypes: + return m.OldSupportedTypes(ctx) + case paymentproviderinstance.FieldEnabled: + return m.OldEnabled(ctx) + case paymentproviderinstance.FieldPaymentMode: + return m.OldPaymentMode(ctx) + case paymentproviderinstance.FieldSortOrder: + return m.OldSortOrder(ctx) + case paymentproviderinstance.FieldLimits: + return m.OldLimits(ctx) + case paymentproviderinstance.FieldRefundEnabled: + return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) + case paymentproviderinstance.FieldCreatedAt: return m.OldCreatedAt(ctx) - case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: return m.OldUpdatedAt(ctx) } - return nil, fmt.Errorf("unknown PaymentOrder field %s", name) + return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) } // SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. 
-func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { - switch name { - case paymentorder.FieldUserID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserID(v) - return nil - case paymentorder.FieldUserEmail: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserEmail(v) - return nil - case paymentorder.FieldUserName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserName(v) - return nil - case paymentorder.FieldUserNotes: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserNotes(v) - return nil - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFeeRate(v) - return nil - case paymentorder.FieldRechargeCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRechargeCode(v) - return nil - case paymentorder.FieldOutTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOutTradeNo(v) - return nil - case paymentorder.FieldPaymentType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaymentType(v) - return nil - case paymentorder.FieldPaymentTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field 
%s", value, name) - } - m.SetPaymentTradeNo(v) - return nil - case paymentorder.FieldPayURL: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayURL(v) - return nil - case paymentorder.FieldQrCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCode(v) - return nil - case paymentorder.FieldQrCodeImg: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCodeImg(v) - return nil - case paymentorder.FieldOrderType: +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentproviderinstance.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetOrderType(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionDays(v) + m.SetProviderKey(v) return nil - case paymentorder.FieldProviderInstanceID: + case paymentproviderinstance.FieldName: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetProviderInstanceID(v) + m.SetName(v) return nil - case paymentorder.FieldStatus: + case paymentproviderinstance.FieldConfig: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - 
m.SetStatus(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundAmount(v) + m.SetConfig(v) return nil - case paymentorder.FieldRefundReason: + case paymentproviderinstance.FieldSupportedTypes: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundReason(v) - return nil - case paymentorder.FieldRefundAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundAt(v) + m.SetSupportedTypes(v) return nil - case paymentorder.FieldForceRefund: + case paymentproviderinstance.FieldEnabled: v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetForceRefund(v) - return nil - case paymentorder.FieldRefundRequestedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestedAt(v) - return nil - case paymentorder.FieldRefundRequestReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestReason(v) + m.SetEnabled(v) return nil - case paymentorder.FieldRefundRequestedBy: + case paymentproviderinstance.FieldPaymentMode: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundRequestedBy(v) - return nil - case paymentorder.FieldExpiresAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetExpiresAt(v) - return nil - case paymentorder.FieldPaidAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaidAt(v) - return nil - case paymentorder.FieldCompletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for 
field %s", value, name) - } - m.SetCompletedAt(v) - return nil - case paymentorder.FieldFailedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFailedAt(v) + m.SetPaymentMode(v) return nil - case paymentorder.FieldFailedReason: - v, ok := value.(string) + case paymentproviderinstance.FieldSortOrder: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetFailedReason(v) + m.SetSortOrder(v) return nil - case paymentorder.FieldClientIP: + case paymentproviderinstance.FieldLimits: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetClientIP(v) + m.SetLimits(v) return nil - case paymentorder.FieldSrcHost: - v, ok := value.(string) + case paymentproviderinstance.FieldRefundEnabled: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSrcHost(v) + m.SetRefundEnabled(v) return nil - case paymentorder.FieldSrcURL: - v, ok := value.(string) + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSrcURL(v) + m.SetAllowUserRefund(v) return nil - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) @@ -15204,33 +23984,15 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { m.SetUpdatedAt(v) return nil } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) } // AddedFields returns all numeric fields 
that were incremented/decremented during // this mutation. -func (m *PaymentOrderMutation) AddedFields() []string { +func (m *PaymentProviderInstanceMutation) AddedFields() []string { var fields []string - if m.addamount != nil { - fields = append(fields, paymentorder.FieldAmount) - } - if m.addpay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) - } - if m.addfee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) - } - if m.addplan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) - } - if m.addsubscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) - } - if m.addsubscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) - } - if m.addrefund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) + if m.addsort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) } return fields } @@ -15238,22 +24000,10 @@ func (m *PaymentOrderMutation) AddedFields() []string { // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. 
-func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { +func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { switch name { - case paymentorder.FieldAmount: - return m.AddedAmount() - case paymentorder.FieldPayAmount: - return m.AddedPayAmount() - case paymentorder.FieldFeeRate: - return m.AddedFeeRate() - case paymentorder.FieldPlanID: - return m.AddedPlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.AddedSubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.AddedSubscriptionDays() - case paymentorder.FieldRefundAmount: - return m.AddedRefundAmount() + case paymentproviderinstance.FieldSortOrder: + return m.AddedSortOrder() } return nil, false } @@ -15261,420 +24011,177 @@ func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { +func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { switch name { - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFeeRate(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: + case paymentproviderinstance.FieldSortOrder: v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddSubscriptionDays(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRefundAmount(v) + m.AddSortOrder(v) return nil } - return fmt.Errorf("unknown PaymentOrder numeric field %s", name) + return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. 
-func (m *PaymentOrderMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(paymentorder.FieldUserNotes) { - fields = append(fields, paymentorder.FieldUserNotes) - } - if m.FieldCleared(paymentorder.FieldPayURL) { - fields = append(fields, paymentorder.FieldPayURL) - } - if m.FieldCleared(paymentorder.FieldQrCode) { - fields = append(fields, paymentorder.FieldQrCode) - } - if m.FieldCleared(paymentorder.FieldQrCodeImg) { - fields = append(fields, paymentorder.FieldQrCodeImg) - } - if m.FieldCleared(paymentorder.FieldPlanID) { - fields = append(fields, paymentorder.FieldPlanID) - } - if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) - } - if m.FieldCleared(paymentorder.FieldSubscriptionDays) { - fields = append(fields, paymentorder.FieldSubscriptionDays) - } - if m.FieldCleared(paymentorder.FieldProviderInstanceID) { - fields = append(fields, paymentorder.FieldProviderInstanceID) - } - if m.FieldCleared(paymentorder.FieldRefundReason) { - fields = append(fields, paymentorder.FieldRefundReason) - } - if m.FieldCleared(paymentorder.FieldRefundAt) { - fields = append(fields, paymentorder.FieldRefundAt) - } - if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { - fields = append(fields, paymentorder.FieldRefundRequestedAt) - } - if m.FieldCleared(paymentorder.FieldRefundRequestReason) { - fields = append(fields, paymentorder.FieldRefundRequestReason) - } - if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { - fields = append(fields, paymentorder.FieldRefundRequestedBy) - } - if m.FieldCleared(paymentorder.FieldPaidAt) { - fields = append(fields, paymentorder.FieldPaidAt) - } - if m.FieldCleared(paymentorder.FieldCompletedAt) { - fields = append(fields, paymentorder.FieldCompletedAt) - } - if m.FieldCleared(paymentorder.FieldFailedAt) { - fields = append(fields, paymentorder.FieldFailedAt) - } - if m.FieldCleared(paymentorder.FieldFailedReason) { - fields = 
append(fields, paymentorder.FieldFailedReason) - } - if m.FieldCleared(paymentorder.FieldSrcURL) { - fields = append(fields, paymentorder.FieldSrcURL) - } - return fields +func (m *PaymentProviderInstanceMutation) ClearedFields() []string { + return nil } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *PaymentOrderMutation) FieldCleared(name string) bool { +func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ClearField(name string) error { - switch name { - case paymentorder.FieldUserNotes: - m.ClearUserNotes() - return nil - case paymentorder.FieldPayURL: - m.ClearPayURL() - return nil - case paymentorder.FieldQrCode: - m.ClearQrCode() - return nil - case paymentorder.FieldQrCodeImg: - m.ClearQrCodeImg() - return nil - case paymentorder.FieldPlanID: - m.ClearPlanID() - return nil - case paymentorder.FieldSubscriptionGroupID: - m.ClearSubscriptionGroupID() - return nil - case paymentorder.FieldSubscriptionDays: - m.ClearSubscriptionDays() - return nil - case paymentorder.FieldProviderInstanceID: - m.ClearProviderInstanceID() - return nil - case paymentorder.FieldRefundReason: - m.ClearRefundReason() - return nil - case paymentorder.FieldRefundAt: - m.ClearRefundAt() - return nil - case paymentorder.FieldRefundRequestedAt: - m.ClearRefundRequestedAt() - return nil - case paymentorder.FieldRefundRequestReason: - m.ClearRefundRequestReason() - return nil - case paymentorder.FieldRefundRequestedBy: - m.ClearRefundRequestedBy() - return nil - case paymentorder.FieldPaidAt: - m.ClearPaidAt() - return nil - case paymentorder.FieldCompletedAt: - m.ClearCompletedAt() - return nil - case paymentorder.FieldFailedAt: - m.ClearFailedAt() - return nil - case 
paymentorder.FieldFailedReason: - m.ClearFailedReason() - return nil - case paymentorder.FieldSrcURL: - m.ClearSrcURL() - return nil - } - return fmt.Errorf("unknown PaymentOrder nullable field %s", name) +func (m *PaymentProviderInstanceMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ResetField(name string) error { +func (m *PaymentProviderInstanceMutation) ResetField(name string) error { switch name { - case paymentorder.FieldUserID: - m.ResetUserID() - return nil - case paymentorder.FieldUserEmail: - m.ResetUserEmail() - return nil - case paymentorder.FieldUserName: - m.ResetUserName() - return nil - case paymentorder.FieldUserNotes: - m.ResetUserNotes() - return nil - case paymentorder.FieldAmount: - m.ResetAmount() - return nil - case paymentorder.FieldPayAmount: - m.ResetPayAmount() - return nil - case paymentorder.FieldFeeRate: - m.ResetFeeRate() - return nil - case paymentorder.FieldRechargeCode: - m.ResetRechargeCode() - return nil - case paymentorder.FieldOutTradeNo: - m.ResetOutTradeNo() - return nil - case paymentorder.FieldPaymentType: - m.ResetPaymentType() - return nil - case paymentorder.FieldPaymentTradeNo: - m.ResetPaymentTradeNo() - return nil - case paymentorder.FieldPayURL: - m.ResetPayURL() - return nil - case paymentorder.FieldQrCode: - m.ResetQrCode() - return nil - case paymentorder.FieldQrCodeImg: - m.ResetQrCodeImg() - return nil - case paymentorder.FieldOrderType: - m.ResetOrderType() - return nil - case paymentorder.FieldPlanID: - m.ResetPlanID() - return nil - case paymentorder.FieldSubscriptionGroupID: - m.ResetSubscriptionGroupID() - return nil - case paymentorder.FieldSubscriptionDays: - m.ResetSubscriptionDays() - return nil - case paymentorder.FieldProviderInstanceID: - 
m.ResetProviderInstanceID() - return nil - case paymentorder.FieldStatus: - m.ResetStatus() - return nil - case paymentorder.FieldRefundAmount: - m.ResetRefundAmount() - return nil - case paymentorder.FieldRefundReason: - m.ResetRefundReason() - return nil - case paymentorder.FieldRefundAt: - m.ResetRefundAt() - return nil - case paymentorder.FieldForceRefund: - m.ResetForceRefund() - return nil - case paymentorder.FieldRefundRequestedAt: - m.ResetRefundRequestedAt() - return nil - case paymentorder.FieldRefundRequestReason: - m.ResetRefundRequestReason() + case paymentproviderinstance.FieldProviderKey: + m.ResetProviderKey() return nil - case paymentorder.FieldRefundRequestedBy: - m.ResetRefundRequestedBy() + case paymentproviderinstance.FieldName: + m.ResetName() return nil - case paymentorder.FieldExpiresAt: - m.ResetExpiresAt() + case paymentproviderinstance.FieldConfig: + m.ResetConfig() return nil - case paymentorder.FieldPaidAt: - m.ResetPaidAt() + case paymentproviderinstance.FieldSupportedTypes: + m.ResetSupportedTypes() return nil - case paymentorder.FieldCompletedAt: - m.ResetCompletedAt() + case paymentproviderinstance.FieldEnabled: + m.ResetEnabled() return nil - case paymentorder.FieldFailedAt: - m.ResetFailedAt() + case paymentproviderinstance.FieldPaymentMode: + m.ResetPaymentMode() return nil - case paymentorder.FieldFailedReason: - m.ResetFailedReason() + case paymentproviderinstance.FieldSortOrder: + m.ResetSortOrder() return nil - case paymentorder.FieldClientIP: - m.ResetClientIP() + case paymentproviderinstance.FieldLimits: + m.ResetLimits() return nil - case paymentorder.FieldSrcHost: - m.ResetSrcHost() + case paymentproviderinstance.FieldRefundEnabled: + m.ResetRefundEnabled() return nil - case paymentorder.FieldSrcURL: - m.ResetSrcURL() + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() return nil - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil - 
case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: m.ResetUpdatedAt() return nil } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentOrderMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.user != nil { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) AddedEdges() []string { + edges := make([]string, 0, 0) return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { - switch name { - case paymentorder.EdgeUser: - if id := m.user; id != nil { - return []ent.Value{*id} - } - } +func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentOrderMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) +func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { +func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. 
-func (m *PaymentOrderMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.cleareduser { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentOrderMutation) EdgeCleared(name string) bool { - switch name { - case paymentorder.EdgeUser: - return m.cleareduser - } +func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { return false } - -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *PaymentOrderMutation) ClearEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ClearUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder unique edge %s", name) + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *PaymentOrderMutation) ResetEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ResetUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder edge %s", name) +func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) } -// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. 
-type PaymentProviderInstanceMutation struct { +// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph. +type PendingAuthSessionMutation struct { config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - allow_user_refund *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + session_token *string + intent *string + provider_type *string + provider_key *string + provider_subject *string + redirect_to *string + resolved_email *string + registration_password_hash *string + upstream_identity_claims *map[string]interface{} + local_flow_state *map[string]interface{} + browser_session_key *string + completion_code_hash *string + completion_code_expires_at *time.Time + email_verified_at *time.Time + password_verified_at *time.Time + totp_verified_at *time.Time + expires_at *time.Time + consumed_at *time.Time + clearedFields map[string]struct{} + target_user *int64 + clearedtarget_user bool + adoption_decision *int64 + clearedadoption_decision bool + done bool + oldValue func(context.Context) (*PendingAuthSession, error) + predicates []predicate.PendingAuthSession } -var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) +var _ ent.Mutation = (*PendingAuthSessionMutation)(nil) -// paymentproviderinstanceOption allows management of the mutation configuration using functional options. -type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) +// pendingauthsessionOption allows management of the mutation configuration using functional options. 
+type pendingauthsessionOption func(*PendingAuthSessionMutation) -// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. -func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { - m := &PaymentProviderInstanceMutation{ +// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity. +func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation { + m := &PendingAuthSessionMutation{ config: c, op: op, - typ: TypePaymentProviderInstance, + typ: TypePendingAuthSession, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -15683,20 +24190,20 @@ func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentprovider return m } -// withPaymentProviderInstanceID sets the ID field of the mutation. -func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { +// withPendingAuthSessionID sets the ID field of the mutation. +func withPendingAuthSessionID(id int64) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { var ( err error once sync.Once - value *PaymentProviderInstance + value *PendingAuthSession ) - m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + value, err = m.Client().PendingAuthSession.Get(ctx, id) } }) return value, err @@ -15705,10 +24212,10 @@ func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { } } -// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. 
-func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { - m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { +// withPendingAuthSession sets the old PendingAuthSession of the mutation. +func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + m.oldValue = func(context.Context) (*PendingAuthSession, error) { return node, nil } m.id = &node.ID @@ -15717,7 +24224,7 @@ func withPaymentProviderInstance(node *PaymentProviderInstance) paymentprovideri // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentProviderInstanceMutation) Client() *Client { +func (m PendingAuthSessionMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -15725,7 +24232,7 @@ func (m PaymentProviderInstanceMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { +func (m PendingAuthSessionMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -15736,7 +24243,7 @@ func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. 
-func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { +func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -15747,7 +24254,7 @@ func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -15756,19 +24263,199 @@ func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, err } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) + return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } +// SetCreatedAt sets the "created_at" field. +func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PendingAuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. 
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetSessionToken sets the "session_token" field. +func (m *PendingAuthSessionMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. +func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *PendingAuthSessionMutation) ResetSessionToken() { + m.session_token = nil +} + +// SetIntent sets the "intent" field. +func (m *PendingAuthSessionMutation) SetIntent(s string) { + m.intent = &s +} + +// Intent returns the value of the "intent" field in the mutation. +func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) { + v := m.intent + if v == nil { + return + } + return *v, true +} + +// OldIntent returns the old "intent" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIntent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIntent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIntent: %w", err) + } + return oldValue.Intent, nil +} + +// ResetIntent resets all changes to the "intent" field. +func (m *PendingAuthSessionMutation) ResetIntent() { + m.intent = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *PendingAuthSessionMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. 
+func (m *PendingAuthSessionMutation) ResetProviderType() { + m.provider_type = nil +} + // SetProviderKey sets the "provider_key" field. -func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { +func (m *PendingAuthSessionMutation) SetProviderKey(s string) { m.provider_key = &s } // ProviderKey returns the value of the "provider_key" field in the mutation. -func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { +func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) { v := m.provider_key if v == nil { return @@ -15776,10 +24463,10 @@ func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) return *v, true } -// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } @@ -15794,435 +24481,703 @@ func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v } // ResetProviderKey resets all changes to the "provider_key" field. -func (m *PaymentProviderInstanceMutation) ResetProviderKey() { +func (m *PendingAuthSessionMutation) ResetProviderKey() { m.provider_key = nil } -// SetName sets the "name" field. 
-func (m *PaymentProviderInstanceMutation) SetName(s string) { - m.name = &s +// SetProviderSubject sets the "provider_subject" field. +func (m *PendingAuthSessionMutation) SetProviderSubject(s string) { + m.provider_subject = &s } -// Name returns the value of the "name" field in the mutation. -func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { - v := m.name +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldProviderSubject requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) } - return oldValue.Name, nil + return oldValue.ProviderSubject, nil } -// ResetName resets all changes to the "name" field. -func (m *PaymentProviderInstanceMutation) ResetName() { - m.name = nil +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *PendingAuthSessionMutation) ResetProviderSubject() { + m.provider_subject = nil } -// SetConfig sets the "config" field. -func (m *PaymentProviderInstanceMutation) SetConfig(s string) { - m._config = &s +// SetTargetUserID sets the "target_user_id" field. +func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) { + m.target_user = &i } -// Config returns the value of the "config" field in the mutation. -func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { - v := m._config +// TargetUserID returns the value of the "target_user_id" field in the mutation. +func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) { + v := m.target_user if v == nil { return } return *v, true } -// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. 
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldConfig is only allowed on UpdateOne operations") + return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldConfig requires an ID field in the mutation") + return v, errors.New("OldTargetUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldConfig: %w", err) + return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err) } - return oldValue.Config, nil + return oldValue.TargetUserID, nil } -// ResetConfig resets all changes to the "config" field. -func (m *PaymentProviderInstanceMutation) ResetConfig() { - m._config = nil +// ClearTargetUserID clears the value of the "target_user_id" field. +func (m *PendingAuthSessionMutation) ClearTargetUserID() { + m.target_user = nil + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} } -// SetSupportedTypes sets the "supported_types" field. -func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { - m.supported_types = &s +// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID] + return ok } -// SupportedTypes returns the value of the "supported_types" field in the mutation. 
-func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { - v := m.supported_types +// ResetTargetUserID resets all changes to the "target_user_id" field. +func (m *PendingAuthSessionMutation) ResetTargetUserID() { + m.target_user = nil + delete(m.clearedFields, pendingauthsession.FieldTargetUserID) +} + +// SetRedirectTo sets the "redirect_to" field. +func (m *PendingAuthSessionMutation) SetRedirectTo(s string) { + m.redirect_to = &s +} + +// RedirectTo returns the value of the "redirect_to" field in the mutation. +func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) { + v := m.redirect_to if v == nil { return } return *v, true } -// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + return v, errors.New("OldRedirectTo requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err) } - return oldValue.SupportedTypes, nil + return oldValue.RedirectTo, nil } -// ResetSupportedTypes resets all changes to the "supported_types" field. -func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { - m.supported_types = nil +// ResetRedirectTo resets all changes to the "redirect_to" field. +func (m *PendingAuthSessionMutation) ResetRedirectTo() { + m.redirect_to = nil } -// SetEnabled sets the "enabled" field. -func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { - m.enabled = &b +// SetResolvedEmail sets the "resolved_email" field. +func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) { + m.resolved_email = &s } -// Enabled returns the value of the "enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// ResolvedEmail returns the value of the "resolved_email" field in the mutation. +func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) { + v := m.resolved_email if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. 
-// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldResolvedEmail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err) } - return oldValue.Enabled, nil + return oldValue.ResolvedEmail, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *PaymentProviderInstanceMutation) ResetEnabled() { - m.enabled = nil +// ResetResolvedEmail resets all changes to the "resolved_email" field. +func (m *PendingAuthSessionMutation) ResetResolvedEmail() { + m.resolved_email = nil } -// SetPaymentMode sets the "payment_mode" field. -func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { - m.payment_mode = &s +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) { + m.registration_password_hash = &s } -// PaymentMode returns the value of the "payment_mode" field in the mutation. 
-func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { - v := m.payment_mode +// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation. +func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) { + v := m.registration_password_hash if v == nil { return } return *v, true } -// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentMode requires an ID field in the mutation") + return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err) } - return oldValue.PaymentMode, nil + return oldValue.RegistrationPasswordHash, nil } -// ResetPaymentMode resets all changes to the "payment_mode" field. 
-func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { - m.payment_mode = nil +// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() { + m.registration_password_hash = nil } -// SetSortOrder sets the "sort_order" field. -func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) { + m.upstream_identity_claims = &value } -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { - v := m.sort_order +// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation. +func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) { + v := m.upstream_identity_claims if v == nil { return } return *v, true } -// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { +func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") + return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err) } - return oldValue.SortOrder, nil + return oldValue.UpstreamIdentityClaims, nil } -// AddSortOrder adds i to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i +// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() { + m.upstream_identity_claims = nil +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) { + m.local_flow_state = &value +} + +// LocalFlowState returns the value of the "local_flow_state" field in the mutation. +func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) { + v := m.local_flow_state + if v == nil { + return } + return *v, true } -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. 
-func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order +// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLocalFlowState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err) + } + return oldValue.LocalFlowState, nil +} + +// ResetLocalFlowState resets all changes to the "local_flow_state" field. +func (m *PendingAuthSessionMutation) ResetLocalFlowState() { + m.local_flow_state = nil +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) { + m.browser_session_key = &s +} + +// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation. +func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) { + v := m.browser_session_key if v == nil { return } return *v, true } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err) + } + return oldValue.BrowserSessionKey, nil } -// SetLimits sets the "limits" field. -func (m *PaymentProviderInstanceMutation) SetLimits(s string) { - m.limits = &s +// ResetBrowserSessionKey resets all changes to the "browser_session_key" field. +func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() { + m.browser_session_key = nil } -// Limits returns the value of the "limits" field in the mutation. -func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { - v := m.limits +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) { + m.completion_code_hash = &s +} + +// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) { + v := m.completion_code_hash if v == nil { return } return *v, true } -// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLimits is only allowed on UpdateOne operations") + return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLimits requires an ID field in the mutation") + return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldLimits: %w", err) + return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err) } - return oldValue.Limits, nil + return oldValue.CompletionCodeHash, nil } -// ResetLimits resets all changes to the "limits" field. -func (m *PaymentProviderInstanceMutation) ResetLimits() { - m.limits = nil +// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() { + m.completion_code_hash = nil } -// SetRefundEnabled sets the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { - m.refund_enabled = &b +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) { + m.completion_code_expires_at = &t } -// RefundEnabled returns the value of the "refund_enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { - v := m.refund_enabled +// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation. 
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) { + v := m.completion_code_expires_at if v == nil { return } return *v, true } -// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundEnabled requires an ID field in the mutation") + return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err) } - return oldValue.RefundEnabled, nil + return oldValue.CompletionCodeExpiresAt, nil } -// ResetRefundEnabled resets all changes to the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { - m.refund_enabled = nil +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. 
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{} } -// SetAllowUserRefund sets the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { - m.allow_user_refund = &b +// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] + return ok } -// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. -func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { - v := m.allow_user_refund +// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) { + m.email_verified_at = &t +} + +// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) { + v := m.email_verified_at if v == nil { return } return *v, true } -// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. 
// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err) } - return oldValue.AllowUserRefund, nil + return oldValue.EmailVerifiedAt, nil } -// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { - m.allow_user_refund = nil +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() { + m.email_verified_at = nil + m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{} } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] + return ok } -// CreatedAt returns the value of the "created_at" field in the mutation. 
-func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() { + m.email_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) { + m.password_verified_at = &t +} + +// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) { + v := m.password_verified_at if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.PasswordVerifiedAt, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { - m.created_at = nil +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() { + m.password_verified_at = nil + m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{} } -// SetUpdatedAt sets the "updated_at" field. -func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] + return ok } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field. 
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() { + m.password_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) { + m.totp_verified_at = &t +} + +// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) { + v := m.totp_verified_at if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
-func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err) + } + return oldValue.TotpVerifiedAt, nil +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() { + m.totp_verified_at = nil + m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{} +} + +// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] + return ok +} + +// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() { + m.totp_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. 
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PendingAuthSessionMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetConsumedAt sets the "consumed_at" field. +func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) { + m.consumed_at = &t +} + +// ConsumedAt returns the value of the "consumed_at" field in the mutation. +func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) { + v := m.consumed_at + if v == nil { + return + } + return *v, true +} + +// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConsumedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err) + } + return oldValue.ConsumedAt, nil +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (m *PendingAuthSessionMutation) ClearConsumedAt() { + m.consumed_at = nil + m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{} +} + +// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt] + return ok +} + +// ResetConsumedAt resets all changes to the "consumed_at" field. +func (m *PendingAuthSessionMutation) ResetConsumedAt() { + m.consumed_at = nil + delete(m.clearedFields, pendingauthsession.FieldConsumedAt) +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (m *PendingAuthSessionMutation) ClearTargetUser() { + m.clearedtarget_user = true + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserCleared reports if the "target_user" edge to the User entity was cleared. +func (m *PendingAuthSessionMutation) TargetUserCleared() bool { + return m.TargetUserIDCleared() || m.clearedtarget_user +} + +// TargetUserIDs returns the "target_user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// TargetUserID instead. It exists only for internal usage by the builders. 
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) { + if id := m.target_user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetTargetUser resets all changes to the "target_user" edge. +func (m *PendingAuthSessionMutation) ResetTargetUser() { + m.target_user = nil + m.clearedtarget_user = false +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id. +func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) { + m.adoption_decision = &id +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (m *PendingAuthSessionMutation) ClearAdoptionDecision() { + m.clearedadoption_decision = true +} + +// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared. +func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool { + return m.clearedadoption_decision +} + +// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation. +func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) { + if m.adoption_decision != nil { + return *m.adoption_decision, true + } + return +} + +// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AdoptionDecisionID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) { + if id := m.adoption_decision; id != nil { + ids = append(ids, *id) } - return oldValue.UpdatedAt, nil + return } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { - m.updated_at = nil +// ResetAdoptionDecision resets all changes to the "adoption_decision" edge. 
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() { + m.adoption_decision = nil + m.clearedadoption_decision = false } -// Where appends a list predicates to the PaymentProviderInstanceMutation builder. -func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { +// Where appends a list predicates to the PendingAuthSessionMutation builder. +func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, +// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentProviderInstance, len(ps)) +func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PendingAuthSession, len(ps)) for i := range ps { p[i] = ps[i] } @@ -16230,60 +25185,87 @@ func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *PaymentProviderInstanceMutation) Op() Op { +func (m *PendingAuthSessionMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *PaymentProviderInstanceMutation) SetOp(op Op) { +func (m *PendingAuthSessionMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (PaymentProviderInstance). -func (m *PaymentProviderInstanceMutation) Type() string { +// Type returns the node type of this mutation (PendingAuthSession). +func (m *PendingAuthSessionMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. 
Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 12) +func (m *PendingAuthSessionMutation) Fields() []string { + fields := make([]string, 0, 21) + if m.created_at != nil { + fields = append(fields, pendingauthsession.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, pendingauthsession.FieldUpdatedAt) + } + if m.session_token != nil { + fields = append(fields, pendingauthsession.FieldSessionToken) + } + if m.intent != nil { + fields = append(fields, pendingauthsession.FieldIntent) + } + if m.provider_type != nil { + fields = append(fields, pendingauthsession.FieldProviderType) + } if m.provider_key != nil { - fields = append(fields, paymentproviderinstance.FieldProviderKey) + fields = append(fields, pendingauthsession.FieldProviderKey) } - if m.name != nil { - fields = append(fields, paymentproviderinstance.FieldName) + if m.provider_subject != nil { + fields = append(fields, pendingauthsession.FieldProviderSubject) } - if m._config != nil { - fields = append(fields, paymentproviderinstance.FieldConfig) + if m.target_user != nil { + fields = append(fields, pendingauthsession.FieldTargetUserID) } - if m.supported_types != nil { - fields = append(fields, paymentproviderinstance.FieldSupportedTypes) + if m.redirect_to != nil { + fields = append(fields, pendingauthsession.FieldRedirectTo) } - if m.enabled != nil { - fields = append(fields, paymentproviderinstance.FieldEnabled) + if m.resolved_email != nil { + fields = append(fields, pendingauthsession.FieldResolvedEmail) } - if m.payment_mode != nil { - fields = append(fields, paymentproviderinstance.FieldPaymentMode) + if m.registration_password_hash != nil { + fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash) } - if m.sort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) + if 
m.upstream_identity_claims != nil { + fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims) } - if m.limits != nil { - fields = append(fields, paymentproviderinstance.FieldLimits) + if m.local_flow_state != nil { + fields = append(fields, pendingauthsession.FieldLocalFlowState) } - if m.refund_enabled != nil { - fields = append(fields, paymentproviderinstance.FieldRefundEnabled) + if m.browser_session_key != nil { + fields = append(fields, pendingauthsession.FieldBrowserSessionKey) } - if m.allow_user_refund != nil { - fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + if m.completion_code_hash != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeHash) } - if m.created_at != nil { - fields = append(fields, paymentproviderinstance.FieldCreatedAt) + if m.completion_code_expires_at != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) } - if m.updated_at != nil { - fields = append(fields, paymentproviderinstance.FieldUpdatedAt) + if m.email_verified_at != nil { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.password_verified_at != nil { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.totp_verified_at != nil { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.expires_at != nil { + fields = append(fields, pendingauthsession.FieldExpiresAt) + } + if m.consumed_at != nil { + fields = append(fields, pendingauthsession.FieldConsumedAt) } return fields } @@ -16291,32 +25273,50 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. 
-func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.ProviderKey() - case paymentproviderinstance.FieldName: - return m.Name() - case paymentproviderinstance.FieldConfig: - return m.Config() - case paymentproviderinstance.FieldSupportedTypes: - return m.SupportedTypes() - case paymentproviderinstance.FieldEnabled: - return m.Enabled() - case paymentproviderinstance.FieldPaymentMode: - return m.PaymentMode() - case paymentproviderinstance.FieldSortOrder: - return m.SortOrder() - case paymentproviderinstance.FieldLimits: - return m.Limits() - case paymentproviderinstance.FieldRefundEnabled: - return m.RefundEnabled() - case paymentproviderinstance.FieldAllowUserRefund: - return m.AllowUserRefund() - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.CreatedAt() - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.UpdatedAt() + case pendingauthsession.FieldSessionToken: + return m.SessionToken() + case pendingauthsession.FieldIntent: + return m.Intent() + case pendingauthsession.FieldProviderType: + return m.ProviderType() + case pendingauthsession.FieldProviderKey: + return m.ProviderKey() + case pendingauthsession.FieldProviderSubject: + return m.ProviderSubject() + case pendingauthsession.FieldTargetUserID: + return m.TargetUserID() + case pendingauthsession.FieldRedirectTo: + return m.RedirectTo() + case pendingauthsession.FieldResolvedEmail: + return m.ResolvedEmail() + case pendingauthsession.FieldRegistrationPasswordHash: + return m.RegistrationPasswordHash() + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.UpstreamIdentityClaims() + case pendingauthsession.FieldLocalFlowState: + return m.LocalFlowState() + case pendingauthsession.FieldBrowserSessionKey: + return 
m.BrowserSessionKey() + case pendingauthsession.FieldCompletionCodeHash: + return m.CompletionCodeHash() + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.CompletionCodeExpiresAt() + case pendingauthsession.FieldEmailVerifiedAt: + return m.EmailVerifiedAt() + case pendingauthsession.FieldPasswordVerifiedAt: + return m.PasswordVerifiedAt() + case pendingauthsession.FieldTotpVerifiedAt: + return m.TotpVerifiedAt() + case pendingauthsession.FieldExpiresAt: + return m.ExpiresAt() + case pendingauthsession.FieldConsumedAt: + return m.ConsumedAt() } return nil, false } @@ -16324,146 +25324,222 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.OldProviderKey(ctx) - case paymentproviderinstance.FieldName: - return m.OldName(ctx) - case paymentproviderinstance.FieldConfig: - return m.OldConfig(ctx) - case paymentproviderinstance.FieldSupportedTypes: - return m.OldSupportedTypes(ctx) - case paymentproviderinstance.FieldEnabled: - return m.OldEnabled(ctx) - case paymentproviderinstance.FieldPaymentMode: - return m.OldPaymentMode(ctx) - case paymentproviderinstance.FieldSortOrder: - return m.OldSortOrder(ctx) - case paymentproviderinstance.FieldLimits: - return m.OldLimits(ctx) - case paymentproviderinstance.FieldRefundEnabled: - return m.OldRefundEnabled(ctx) - case paymentproviderinstance.FieldAllowUserRefund: - return m.OldAllowUserRefund(ctx) - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.OldCreatedAt(ctx) - case 
paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case pendingauthsession.FieldSessionToken: + return m.OldSessionToken(ctx) + case pendingauthsession.FieldIntent: + return m.OldIntent(ctx) + case pendingauthsession.FieldProviderType: + return m.OldProviderType(ctx) + case pendingauthsession.FieldProviderKey: + return m.OldProviderKey(ctx) + case pendingauthsession.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case pendingauthsession.FieldTargetUserID: + return m.OldTargetUserID(ctx) + case pendingauthsession.FieldRedirectTo: + return m.OldRedirectTo(ctx) + case pendingauthsession.FieldResolvedEmail: + return m.OldResolvedEmail(ctx) + case pendingauthsession.FieldRegistrationPasswordHash: + return m.OldRegistrationPasswordHash(ctx) + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.OldUpstreamIdentityClaims(ctx) + case pendingauthsession.FieldLocalFlowState: + return m.OldLocalFlowState(ctx) + case pendingauthsession.FieldBrowserSessionKey: + return m.OldBrowserSessionKey(ctx) + case pendingauthsession.FieldCompletionCodeHash: + return m.OldCompletionCodeHash(ctx) + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.OldCompletionCodeExpiresAt(ctx) + case pendingauthsession.FieldEmailVerifiedAt: + return m.OldEmailVerifiedAt(ctx) + case pendingauthsession.FieldPasswordVerifiedAt: + return m.OldPasswordVerifiedAt(ctx) + case pendingauthsession.FieldTotpVerifiedAt: + return m.OldTotpVerifiedAt(ctx) + case pendingauthsession.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case pendingauthsession.FieldConsumedAt: + return m.OldConsumedAt(ctx) } - return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return nil, fmt.Errorf("unknown PendingAuthSession field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. 
-func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case pendingauthsession.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case pendingauthsession.FieldSessionToken: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetProviderKey(v) + m.SetSessionToken(v) return nil - case paymentproviderinstance.FieldName: + case pendingauthsession.FieldIntent: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetName(v) + m.SetIntent(v) return nil - case paymentproviderinstance.FieldConfig: + case pendingauthsession.FieldProviderType: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetConfig(v) + m.SetProviderType(v) return nil - case paymentproviderinstance.FieldSupportedTypes: + case pendingauthsession.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSupportedTypes(v) + m.SetProviderKey(v) return nil - case paymentproviderinstance.FieldEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldProviderSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetEnabled(v) + m.SetProviderSubject(v) return nil - case paymentproviderinstance.FieldPaymentMode: + case pendingauthsession.FieldTargetUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", 
value, name) + } + m.SetTargetUserID(v) + return nil + case pendingauthsession.FieldRedirectTo: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPaymentMode(v) + m.SetRedirectTo(v) return nil - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) + case pendingauthsession.FieldResolvedEmail: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSortOrder(v) + m.SetResolvedEmail(v) return nil - case paymentproviderinstance.FieldLimits: + case pendingauthsession.FieldRegistrationPasswordHash: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLimits(v) + m.SetRegistrationPasswordHash(v) return nil - case paymentproviderinstance.FieldRefundEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldUpstreamIdentityClaims: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundEnabled(v) + m.SetUpstreamIdentityClaims(v) return nil - case paymentproviderinstance.FieldAllowUserRefund: - v, ok := value.(bool) + case pendingauthsession.FieldLocalFlowState: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAllowUserRefund(v) + m.SetLocalFlowState(v) return nil - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldBrowserSessionKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBrowserSessionKey(v) + return nil + case pendingauthsession.FieldCompletionCodeHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeHash(v) + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for 
field %s", value, name) } - m.SetCreatedAt(v) + m.SetCompletionCodeExpiresAt(v) return nil - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldEmailVerifiedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetEmailVerifiedAt(v) + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPasswordVerifiedAt(v) + return nil + case pendingauthsession.FieldTotpVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpVerifiedAt(v) + return nil + case pendingauthsession.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case pendingauthsession.FieldConsumedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsumedAt(v) return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *PaymentProviderInstanceMutation) AddedFields() []string { +func (m *PendingAuthSessionMutation) AddedFields() []string { var fields []string - if m.addsort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. 
-func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldSortOrder: - return m.AddedSortOrder() } return nil, false } @@ -16471,128 +25547,231 @@ func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bo // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) - return nil } - return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) + return fmt.Errorf("unknown PendingAuthSession numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. 
-func (m *PaymentProviderInstanceMutation) ClearedFields() []string { - return nil +func (m *PendingAuthSessionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(pendingauthsession.FieldTargetUserID) { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldConsumedAt) { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { +func (m *PendingAuthSessionMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. 
-func (m *PaymentProviderInstanceMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) +func (m *PendingAuthSessionMutation) ClearField(name string) error { + switch name { + case pendingauthsession.FieldTargetUserID: + m.ClearTargetUserID() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ClearCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ClearEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ClearPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ClearTotpVerifiedAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ClearConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. 
-func (m *PaymentProviderInstanceMutation) ResetField(name string) error { +func (m *PendingAuthSessionMutation) ResetField(name string) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case pendingauthsession.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case pendingauthsession.FieldSessionToken: + m.ResetSessionToken() + return nil + case pendingauthsession.FieldIntent: + m.ResetIntent() + return nil + case pendingauthsession.FieldProviderType: + m.ResetProviderType() + return nil + case pendingauthsession.FieldProviderKey: m.ResetProviderKey() return nil - case paymentproviderinstance.FieldName: - m.ResetName() + case pendingauthsession.FieldProviderSubject: + m.ResetProviderSubject() return nil - case paymentproviderinstance.FieldConfig: - m.ResetConfig() + case pendingauthsession.FieldTargetUserID: + m.ResetTargetUserID() return nil - case paymentproviderinstance.FieldSupportedTypes: - m.ResetSupportedTypes() + case pendingauthsession.FieldRedirectTo: + m.ResetRedirectTo() return nil - case paymentproviderinstance.FieldEnabled: - m.ResetEnabled() + case pendingauthsession.FieldResolvedEmail: + m.ResetResolvedEmail() return nil - case paymentproviderinstance.FieldPaymentMode: - m.ResetPaymentMode() + case pendingauthsession.FieldRegistrationPasswordHash: + m.ResetRegistrationPasswordHash() return nil - case paymentproviderinstance.FieldSortOrder: - m.ResetSortOrder() + case pendingauthsession.FieldUpstreamIdentityClaims: + m.ResetUpstreamIdentityClaims() return nil - case paymentproviderinstance.FieldLimits: - m.ResetLimits() + case pendingauthsession.FieldLocalFlowState: + m.ResetLocalFlowState() return nil - case paymentproviderinstance.FieldRefundEnabled: - m.ResetRefundEnabled() + case pendingauthsession.FieldBrowserSessionKey: + m.ResetBrowserSessionKey() return nil - case paymentproviderinstance.FieldAllowUserRefund: - m.ResetAllowUserRefund() + case 
pendingauthsession.FieldCompletionCodeHash: + m.ResetCompletionCodeHash() return nil - case paymentproviderinstance.FieldCreatedAt: - m.ResetCreatedAt() + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ResetCompletionCodeExpiresAt() return nil - case paymentproviderinstance.FieldUpdatedAt: - m.ResetUpdatedAt() + case pendingauthsession.FieldEmailVerifiedAt: + m.ResetEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ResetPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ResetTotpVerifiedAt() + return nil + case pendingauthsession.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ResetConsumedAt() return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentProviderInstanceMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.target_user != nil { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.adoption_decision != nil { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value { + switch name { + case pendingauthsession.EdgeTargetUser: + if id := m.target_user; id != nil { + return []ent.Value{*id} + } + case pendingauthsession.EdgeAdoptionDecision: + if id := m.adoption_decision; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
-func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedtarget_user { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.clearedadoption_decision { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { +func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool { + switch name { + case pendingauthsession.EdgeTargetUser: + return m.clearedtarget_user + case pendingauthsession.EdgeAdoptionDecision: + return m.clearedadoption_decision + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. 
-func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) +func (m *PendingAuthSessionMutation) ClearEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ClearTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ClearAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) +func (m *PendingAuthSessionMutation) ResetEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ResetTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ResetAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession edge %s", name) } // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. 
@@ -28264,6 +37443,9 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + signup_source *string + last_login_at *time.Time + last_active_at *time.Time balance_notify_enabled *bool balance_notify_threshold_type *string balance_notify_threshold *float64 @@ -28271,6 +37453,8 @@ type UserMutation struct { balance_notify_extra_emails *string total_recharged *float64 addtotal_recharged *float64 + rpm_limit *int + addrpm_limit *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -28302,6 +37486,12 @@ type UserMutation struct { payment_orders map[int64]struct{} removedpayment_orders map[int64]struct{} clearedpayment_orders bool + auth_identities map[int64]struct{} + removedauth_identities map[int64]struct{} + clearedauth_identities bool + pending_auth_sessions map[int64]struct{} + removedpending_auth_sessions map[int64]struct{} + clearedpending_auth_sessions bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -28988,6 +38178,140 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSignupSource sets the "signup_source" field. +func (m *UserMutation) SetSignupSource(s string) { + m.signup_source = &s +} + +// SignupSource returns the value of the "signup_source" field in the mutation. +func (m *UserMutation) SignupSource() (r string, exists bool) { + v := m.signup_source + if v == nil { + return + } + return *v, true +} + +// OldSignupSource returns the old "signup_source" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignupSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignupSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignupSource: %w", err) + } + return oldValue.SignupSource, nil +} + +// ResetSignupSource resets all changes to the "signup_source" field. +func (m *UserMutation) ResetSignupSource() { + m.signup_source = nil +} + +// SetLastLoginAt sets the "last_login_at" field. +func (m *UserMutation) SetLastLoginAt(t time.Time) { + m.last_login_at = &t +} + +// LastLoginAt returns the value of the "last_login_at" field in the mutation. +func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) { + v := m.last_login_at + if v == nil { + return + } + return *v, true +} + +// OldLastLoginAt returns the old "last_login_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLoginAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err) + } + return oldValue.LastLoginAt, nil +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. 
+func (m *UserMutation) ClearLastLoginAt() { + m.last_login_at = nil + m.clearedFields[user.FieldLastLoginAt] = struct{}{} +} + +// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation. +func (m *UserMutation) LastLoginAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastLoginAt] + return ok +} + +// ResetLastLoginAt resets all changes to the "last_login_at" field. +func (m *UserMutation) ResetLastLoginAt() { + m.last_login_at = nil + delete(m.clearedFields, user.FieldLastLoginAt) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (m *UserMutation) SetLastActiveAt(t time.Time) { + m.last_active_at = &t +} + +// LastActiveAt returns the value of the "last_active_at" field in the mutation. +func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) { + v := m.last_active_at + if v == nil { + return + } + return *v, true +} + +// OldLastActiveAt returns the old "last_active_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActiveAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err) + } + return oldValue.LastActiveAt, nil +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (m *UserMutation) ClearLastActiveAt() { + m.last_active_at = nil + m.clearedFields[user.FieldLastActiveAt] = struct{}{} +} + +// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation. 
+func (m *UserMutation) LastActiveAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastActiveAt] + return ok +} + +// ResetLastActiveAt resets all changes to the "last_active_at" field. +func (m *UserMutation) ResetLastActiveAt() { + m.last_active_at = nil + delete(m.clearedFields, user.FieldLastActiveAt) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { m.balance_notify_enabled = &b @@ -29222,6 +38546,62 @@ func (m *UserMutation) ResetTotalRecharged() { m.addtotal_recharged = nil } +// SetRpmLimit sets the "rpm_limit" field. +func (m *UserMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. +func (m *UserMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *UserMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. 
+func (m *UserMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *UserMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -29762,6 +39142,114 @@ func (m *UserMutation) ResetPaymentOrders() { m.removedpayment_orders = nil } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids. +func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) { + if m.auth_identities == nil { + m.auth_identities = make(map[int64]struct{}) + } + for i := range ids { + m.auth_identities[ids[i]] = struct{}{} + } +} + +// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) ClearAuthIdentities() { + m.clearedauth_identities = true +} + +// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared. +func (m *UserMutation) AuthIdentitiesCleared() bool { + return m.clearedauth_identities +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs. +func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) { + if m.removedauth_identities == nil { + m.removedauth_identities = make(map[int64]struct{}) + } + for i := range ids { + delete(m.auth_identities, ids[i]) + m.removedauth_identities[ids[i]] = struct{}{} + } +} + +// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) { + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return +} + +// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation. 
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) { + for id := range m.auth_identities { + ids = append(ids, id) + } + return +} + +// ResetAuthIdentities resets all changes to the "auth_identities" edge. +func (m *UserMutation) ResetAuthIdentities() { + m.auth_identities = nil + m.clearedauth_identities = false + m.removedauth_identities = nil +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids. +func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) { + if m.pending_auth_sessions == nil { + m.pending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + m.pending_auth_sessions[ids[i]] = struct{}{} + } +} + +// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) ClearPendingAuthSessions() { + m.clearedpending_auth_sessions = true +} + +// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared. +func (m *UserMutation) PendingAuthSessionsCleared() bool { + return m.clearedpending_auth_sessions +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) { + if m.removedpending_auth_sessions == nil { + m.removedpending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.pending_auth_sessions, ids[i]) + m.removedpending_auth_sessions[ids[i]] = struct{}{} + } +} + +// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) { + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return +} + +// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation. 
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) { + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return +} + +// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge. +func (m *UserMutation) ResetPendingAuthSessions() { + m.pending_auth_sessions = nil + m.clearedpending_auth_sessions = false + m.removedpending_auth_sessions = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -29796,7 +39284,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29839,6 +39327,15 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.signup_source != nil { + fields = append(fields, user.FieldSignupSource) + } + if m.last_login_at != nil { + fields = append(fields, user.FieldLastLoginAt) + } + if m.last_active_at != nil { + fields = append(fields, user.FieldLastActiveAt) + } if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } @@ -29854,6 +39351,9 @@ func (m *UserMutation) Fields() []string { if m.total_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.rpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -29890,6 +39390,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSignupSource: + return m.SignupSource() + case user.FieldLastLoginAt: + return m.LastLoginAt() + case user.FieldLastActiveAt: + return m.LastActiveAt() case 
user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() case user.FieldBalanceNotifyThresholdType: @@ -29900,6 +39406,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.BalanceNotifyExtraEmails() case user.FieldTotalRecharged: return m.TotalRecharged() + case user.FieldRpmLimit: + return m.RpmLimit() } return nil, false } @@ -29937,6 +39445,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSignupSource: + return m.OldSignupSource(ctx) + case user.FieldLastLoginAt: + return m.OldLastLoginAt(ctx) + case user.FieldLastActiveAt: + return m.OldLastActiveAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) case user.FieldBalanceNotifyThresholdType: @@ -29947,6 +39461,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldBalanceNotifyExtraEmails(ctx) case user.FieldTotalRecharged: return m.OldTotalRecharged(ctx) + case user.FieldRpmLimit: + return m.OldRpmLimit(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -30054,6 +39570,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSignupSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignupSource(v) + return nil + case user.FieldLastLoginAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLoginAt(v) + return nil + case user.FieldLastActiveAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActiveAt(v) + return nil case user.FieldBalanceNotifyEnabled: v, ok := value.(bool) if !ok { @@ -30089,6 +39626,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) 
error { } m.SetTotalRecharged(v) return nil + case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -30109,6 +39653,9 @@ func (m *UserMutation) AddedFields() []string { if m.addtotal_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.addrpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -30125,6 +39672,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalanceNotifyThreshold() case user.FieldTotalRecharged: return m.AddedTotalRecharged() + case user.FieldRpmLimit: + return m.AddedRpmLimit() } return nil, false } @@ -30162,6 +39711,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddTotalRecharged(v) return nil + case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -30179,6 +39735,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldLastLoginAt) { + fields = append(fields, user.FieldLastLoginAt) + } + if m.FieldCleared(user.FieldLastActiveAt) { + fields = append(fields, user.FieldLastActiveAt) + } if m.FieldCleared(user.FieldBalanceNotifyThreshold) { fields = append(fields, user.FieldBalanceNotifyThreshold) } @@ -30205,6 +39767,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldLastLoginAt: + m.ClearLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ClearLastActiveAt() + return nil case user.FieldBalanceNotifyThreshold: m.ClearBalanceNotifyThreshold() return nil @@ -30258,6 
+39826,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSignupSource: + m.ResetSignupSource() + return nil + case user.FieldLastLoginAt: + m.ResetLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ResetLastActiveAt() + return nil case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil @@ -30273,13 +39850,16 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotalRecharged: m.ResetTotalRecharged() return nil + case user.FieldRpmLimit: + m.ResetRpmLimit() + return nil } return fmt.Errorf("unknown User field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30310,6 +39890,12 @@ func (m *UserMutation) AddedEdges() []string { if m.payment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.auth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.pending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30377,13 +39963,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.auth_identities)) + for id := range m.auth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.pending_auth_sessions)) + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30414,6 +40012,12 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpayment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.removedauth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.removedpending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30481,13 +40085,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.removedauth_identities)) + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions)) + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. 
func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -30518,6 +40134,12 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpayment_orders { edges = append(edges, user.EdgePaymentOrders) } + if m.clearedauth_identities { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.clearedpending_auth_sessions { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30545,6 +40167,10 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedpromo_code_usages case user.EdgePaymentOrders: return m.clearedpayment_orders + case user.EdgeAuthIdentities: + return m.clearedauth_identities + case user.EdgePendingAuthSessions: + return m.clearedpending_auth_sessions } return false } @@ -30591,6 +40217,12 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePaymentOrders: m.ResetPaymentOrders() return nil + case user.EdgeAuthIdentities: + m.ResetAuthIdentities() + return nil + case user.EdgePendingAuthSessions: + m.ResetPendingAuthSessions() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go index 6ea3e70981d1884751b7512e541f53c057c1c206..b131b8c8804575eb74196359508d52db2122391f 100644 --- a/backend/ent/paymentorder.go +++ b/backend/ent/paymentorder.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -56,6 +57,10 @@ type PaymentOrder struct { SubscriptionDays *int `json:"subscription_days,omitempty"` // ProviderInstanceID holds the value of the "provider_instance_id" field. ProviderInstanceID *string `json:"provider_instance_id,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey *string `json:"provider_key,omitempty"` + // ProviderSnapshot holds the value of the "provider_snapshot" field. 
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` // RefundAmount holds the value of the "refund_amount" field. @@ -123,13 +128,15 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case paymentorder.FieldProviderSnapshot: + values[i] = new([]byte) case paymentorder.FieldForceRefund: values[i] = new(sql.NullBool) case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount: values[i] = new(sql.NullFloat64) case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays: values[i] = new(sql.NullInt64) - case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: + case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, 
paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: values[i] = new(sql.NullString) case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -276,6 +283,21 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error { _m.ProviderInstanceID = new(string) *_m.ProviderInstanceID = value.String } + case paymentorder.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = new(string) + *_m.ProviderKey = value.String + } + case paymentorder.FieldProviderSnapshot: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil { + return fmt.Errorf("unmarshal field provider_snapshot: %w", err) + } + } case paymentorder.FieldStatus: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) @@ -508,6 +530,14 @@ func (_m *PaymentOrder) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ProviderKey; v != nil { + builder.WriteString("provider_key=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("provider_snapshot=") + builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot)) + builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go index 
4467b2b635896402c3254245eff2fec0d8fb4136..6288379434280fe4c3cee2a8294ebdf98f686ed2 100644 --- a/backend/ent/paymentorder/paymentorder.go +++ b/backend/ent/paymentorder/paymentorder.go @@ -52,6 +52,10 @@ const ( FieldSubscriptionDays = "subscription_days" // FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database. FieldProviderInstanceID = "provider_instance_id" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database. + FieldProviderSnapshot = "provider_snapshot" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" // FieldRefundAmount holds the string denoting the refund_amount field in the database. @@ -123,6 +127,8 @@ var Columns = []string{ FieldSubscriptionGroupID, FieldSubscriptionDays, FieldProviderInstanceID, + FieldProviderKey, + FieldProviderSnapshot, FieldStatus, FieldRefundAmount, FieldRefundReason, @@ -176,6 +182,8 @@ var ( OrderTypeValidator func(string) error // ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. ProviderInstanceIDValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error // DefaultStatus holds the default value on creation for the "status" field. DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. @@ -301,6 +309,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc() } +// ByProviderKey orders the results by the provider_key field. 
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + // ByStatus orders the results by the status field. func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go index 78520fac4286fd7281262789aa3154697f1e3951..e96bf51ebd09499c5f478fa88720e669f3b02894 100644 --- a/backend/ent/paymentorder/where.go +++ b/backend/ent/paymentorder/where.go @@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v)) } +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. func Status(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) @@ -1360,6 +1365,91 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v)) } +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. 
+func ProviderKeyIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field. 
+func ProviderKeyIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey)) +} + +// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field. +func ProviderKeyNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field. +func ProviderSnapshotIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot)) +} + +// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field. +func ProviderSnapshotNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot)) +} + // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go index 030983390124b2840304ae5f021b6634ed3abdba..3ee24f8e918dbd18a1f3aad0c805bf870af6eebd 100644 --- a/backend/ent/paymentorder_create.go +++ b/backend/ent/paymentorder_create.go @@ -225,6 +225,26 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO return _c } +// SetProviderKey sets the "provider_key" field. 
+func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate { + if v != nil { + _c.SetProviderKey(*v) + } + return _c +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate { + _c.mutation.SetProviderSnapshot(v) + return _c +} + // SetStatus sets the "status" field. func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate { _c.mutation.SetStatus(v) @@ -602,6 +622,11 @@ func (_c *PaymentOrderCreate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if _, ok := _c.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)} } @@ -748,6 +773,14 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec) _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value) _node.ProviderInstanceID = &value } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = &value + } + if value, ok := _c.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + _node.ProviderSnapshot = value + } if value, ok := _c.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) 
_node.Status = value @@ -1201,6 +1234,42 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert { return u } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderKey) + return u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderKey) + return u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderSnapshot, v) + return u +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderSnapshot) + return u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderSnapshot) + return u +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert { u.Set(paymentorder.FieldStatus, v) @@ -1880,6 +1949,48 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. 
+func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderSnapshot(v) + }) +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderSnapshot() + }) +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderSnapshot() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne { return u.Update(func(s *PaymentOrderUpsert) { @@ -2770,6 +2881,48 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. 
+func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderSnapshot(v) + }) +} + +// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderSnapshot() + }) +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderSnapshot() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk { return u.Update(func(s *PaymentOrderUpsert) { diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go index 5978fc29148618f828e6586f8baf59d24c64ed1f..378e0dad2f90f2233387fd8bb02bf7018a09ed69 100644 --- a/backend/ent/paymentorder_update.go +++ b/backend/ent/paymentorder_update.go @@ -385,6 +385,38 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate { return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. 
+func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate { + _u.mutation.ClearProviderKey() + return _u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate { + _u.mutation.SetProviderSnapshot(v) + return _u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate { + _u.mutation.ClearProviderSnapshot() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate { _u.mutation.SetStatus(v) @@ -776,6 +808,11 @@ func (_u *PaymentOrderUpdate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -910,6 +947,18 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } + if value, ok := _u.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + } + if 
_u.mutation.ProviderSnapshotCleared() { + _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } @@ -1399,6 +1448,38 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderKey() + return _u +} + +// SetProviderSnapshot sets the "provider_snapshot" field. +func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne { + _u.mutation.SetProviderSnapshot(v) + return _u +} + +// ClearProviderSnapshot clears the value of the "provider_snapshot" field. +func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderSnapshot() + return _u +} + // SetStatus sets the "status" field. 
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne { _u.mutation.SetStatus(v) @@ -1803,6 +1884,11 @@ func (_u *PaymentOrderUpdateOne) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -1954,6 +2040,18 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } + if value, ok := _u.mutation.ProviderSnapshot(); ok { + _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value) + } + if _u.mutation.ProviderSnapshotCleared() { + _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go new file mode 100644 index 0000000000000000000000000000000000000000..e77c065f779add6dc6dd6cbf860bfda6dfe418ba --- /dev/null +++ b/backend/ent/pendingauthsession.go @@ -0,0 +1,399 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSession is the model entity for the PendingAuthSession schema. +type PendingAuthSession struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken string `json:"session_token,omitempty"` + // Intent holds the value of the "intent" field. + Intent string `json:"intent,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // TargetUserID holds the value of the "target_user_id" field. + TargetUserID *int64 `json:"target_user_id,omitempty"` + // RedirectTo holds the value of the "redirect_to" field. + RedirectTo string `json:"redirect_to,omitempty"` + // ResolvedEmail holds the value of the "resolved_email" field. + ResolvedEmail string `json:"resolved_email,omitempty"` + // RegistrationPasswordHash holds the value of the "registration_password_hash" field. + RegistrationPasswordHash string `json:"registration_password_hash,omitempty"` + // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field. 
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"` + // LocalFlowState holds the value of the "local_flow_state" field. + LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"` + // BrowserSessionKey holds the value of the "browser_session_key" field. + BrowserSessionKey string `json:"browser_session_key,omitempty"` + // CompletionCodeHash holds the value of the "completion_code_hash" field. + CompletionCodeHash string `json:"completion_code_hash,omitempty"` + // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field. + CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"` + // EmailVerifiedAt holds the value of the "email_verified_at" field. + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + // PasswordVerifiedAt holds the value of the "password_verified_at" field. + PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"` + // TotpVerifiedAt holds the value of the "totp_verified_at" field. + TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // ConsumedAt holds the value of the "consumed_at" field. + ConsumedAt *time.Time `json:"consumed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PendingAuthSessionQuery when eager-loading is set. + Edges PendingAuthSessionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph. +type PendingAuthSessionEdges struct { + // TargetUser holds the value of the target_user edge. + TargetUser *User `json:"target_user,omitempty"` + // AdoptionDecision holds the value of the adoption_decision edge. 
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// TargetUserOrErr returns the TargetUser value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) { + if e.TargetUser != nil { + return e.TargetUser, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "target_user"} +} + +// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) { + if e.AdoptionDecision != nil { + return e.AdoptionDecision, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: identityadoptiondecision.Label} + } + return nil, &NotLoadedError{edge: "adoption_decision"} +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState: + values[i] = new([]byte) + case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID: + values[i] = new(sql.NullInt64) + case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash: + values[i] = new(sql.NullString) + case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the PendingAuthSession fields. 
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case pendingauthsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case pendingauthsession.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case pendingauthsession.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = value.String + } + case pendingauthsession.FieldIntent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field intent", values[i]) + } else if value.Valid { + _m.Intent = value.String + } + case pendingauthsession.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case pendingauthsession.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case pendingauthsession.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) 
+ } else if value.Valid { + _m.ProviderSubject = value.String + } + case pendingauthsession.FieldTargetUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field target_user_id", values[i]) + } else if value.Valid { + _m.TargetUserID = new(int64) + *_m.TargetUserID = value.Int64 + } + case pendingauthsession.FieldRedirectTo: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field redirect_to", values[i]) + } else if value.Valid { + _m.RedirectTo = value.String + } + case pendingauthsession.FieldResolvedEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resolved_email", values[i]) + } else if value.Valid { + _m.ResolvedEmail = value.String + } + case pendingauthsession.FieldRegistrationPasswordHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i]) + } else if value.Valid { + _m.RegistrationPasswordHash = value.String + } + case pendingauthsession.FieldUpstreamIdentityClaims: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil { + return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err) + } + } + case pendingauthsession.FieldLocalFlowState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field local_flow_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil { + return fmt.Errorf("unmarshal field local_flow_state: %w", err) + } + } + case pendingauthsession.FieldBrowserSessionKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field 
browser_session_key", values[i]) + } else if value.Valid { + _m.BrowserSessionKey = value.String + } + case pendingauthsession.FieldCompletionCodeHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i]) + } else if value.Valid { + _m.CompletionCodeHash = value.String + } + case pendingauthsession.FieldCompletionCodeExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i]) + } else if value.Valid { + _m.CompletionCodeExpiresAt = new(time.Time) + *_m.CompletionCodeExpiresAt = value.Time + } + case pendingauthsession.FieldEmailVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field email_verified_at", values[i]) + } else if value.Valid { + _m.EmailVerifiedAt = new(time.Time) + *_m.EmailVerifiedAt = value.Time + } + case pendingauthsession.FieldPasswordVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field password_verified_at", values[i]) + } else if value.Valid { + _m.PasswordVerifiedAt = new(time.Time) + *_m.PasswordVerifiedAt = value.Time + } + case pendingauthsession.FieldTotpVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i]) + } else if value.Valid { + _m.TotpVerifiedAt = new(time.Time) + *_m.TotpVerifiedAt = value.Time + } + case pendingauthsession.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + case pendingauthsession.FieldConsumedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field consumed_at", values[i]) + } else if value.Valid { + _m.ConsumedAt = new(time.Time) + 
*_m.ConsumedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *PendingAuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryTargetUser() *UserQuery { + return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m) +} + +// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m) +} + +// Update returns a builder for updating this PendingAuthSession. +// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne { + return NewPendingAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *PendingAuthSession) Unwrap() *PendingAuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: PendingAuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *PendingAuthSession) String() string { + var builder strings.Builder + builder.WriteString("PendingAuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("session_token=") + builder.WriteString(_m.SessionToken) + builder.WriteString(", ") + builder.WriteString("intent=") + builder.WriteString(_m.Intent) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.TargetUserID; v != nil { + builder.WriteString("target_user_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("redirect_to=") + builder.WriteString(_m.RedirectTo) + builder.WriteString(", ") + builder.WriteString("resolved_email=") + builder.WriteString(_m.ResolvedEmail) + builder.WriteString(", ") + builder.WriteString("registration_password_hash=") + builder.WriteString(_m.RegistrationPasswordHash) + builder.WriteString(", ") + builder.WriteString("upstream_identity_claims=") + builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims)) + builder.WriteString(", ") + builder.WriteString("local_flow_state=") + builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState)) + builder.WriteString(", ") + builder.WriteString("browser_session_key=") + builder.WriteString(_m.BrowserSessionKey) + builder.WriteString(", ") + builder.WriteString("completion_code_hash=") + builder.WriteString(_m.CompletionCodeHash) + builder.WriteString(", ") + if v := _m.CompletionCodeExpiresAt; v 
!= nil { + builder.WriteString("completion_code_expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.EmailVerifiedAt; v != nil { + builder.WriteString("email_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.PasswordVerifiedAt; v != nil { + builder.WriteString("password_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TotpVerifiedAt; v != nil { + builder.WriteString("totp_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConsumedAt; v != nil { + builder.WriteString("consumed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// PendingAuthSessions is a parsable slice of PendingAuthSession. +type PendingAuthSessions []*PendingAuthSession diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go new file mode 100644 index 0000000000000000000000000000000000000000..8a3ac9bf783f191c796ce78f71c6d89130ae3c1c --- /dev/null +++ b/backend/ent/pendingauthsession/pendingauthsession.go @@ -0,0 +1,279 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the pendingauthsession type in the database. + Label = "pending_auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. 
+ FieldUpdatedAt = "updated_at" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldIntent holds the string denoting the intent field in the database. + FieldIntent = "intent" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldTargetUserID holds the string denoting the target_user_id field in the database. + FieldTargetUserID = "target_user_id" + // FieldRedirectTo holds the string denoting the redirect_to field in the database. + FieldRedirectTo = "redirect_to" + // FieldResolvedEmail holds the string denoting the resolved_email field in the database. + FieldResolvedEmail = "resolved_email" + // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database. + FieldRegistrationPasswordHash = "registration_password_hash" + // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database. + FieldUpstreamIdentityClaims = "upstream_identity_claims" + // FieldLocalFlowState holds the string denoting the local_flow_state field in the database. + FieldLocalFlowState = "local_flow_state" + // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database. + FieldBrowserSessionKey = "browser_session_key" + // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database. + FieldCompletionCodeHash = "completion_code_hash" + // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database. 
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at" + // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database. + FieldEmailVerifiedAt = "email_verified_at" + // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database. + FieldPasswordVerifiedAt = "password_verified_at" + // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database. + FieldTotpVerifiedAt = "totp_verified_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldConsumedAt holds the string denoting the consumed_at field in the database. + FieldConsumedAt = "consumed_at" + // EdgeTargetUser holds the string denoting the target_user edge name in mutations. + EdgeTargetUser = "target_user" + // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations. + EdgeAdoptionDecision = "adoption_decision" + // Table holds the table name of the pendingauthsession in the database. + Table = "pending_auth_sessions" + // TargetUserTable is the table that holds the target_user relation/edge. + TargetUserTable = "pending_auth_sessions" + // TargetUserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + TargetUserInverseTable = "users" + // TargetUserColumn is the table column denoting the target_user relation/edge. + TargetUserColumn = "target_user_id" + // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge. + AdoptionDecisionTable = "identity_adoption_decisions" + // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. 
+ AdoptionDecisionInverseTable = "identity_adoption_decisions" + // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge. + AdoptionDecisionColumn = "pending_auth_session_id" +) + +// Columns holds all SQL columns for pendingauthsession fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldSessionToken, + FieldIntent, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldTargetUserID, + FieldRedirectTo, + FieldResolvedEmail, + FieldRegistrationPasswordHash, + FieldUpstreamIdentityClaims, + FieldLocalFlowState, + FieldBrowserSessionKey, + FieldCompletionCodeHash, + FieldCompletionCodeExpiresAt, + FieldEmailVerifiedAt, + FieldPasswordVerifiedAt, + FieldTotpVerifiedAt, + FieldExpiresAt, + FieldConsumedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + SessionTokenValidator func(string) error + // IntentValidator is a validator for the "intent" field. It is called by the builders before save. + IntentValidator func(string) error + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. 
+ ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultRedirectTo holds the default value on creation for the "redirect_to" field. + DefaultRedirectTo string + // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field. + DefaultResolvedEmail string + // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field. + DefaultRegistrationPasswordHash string + // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field. + DefaultUpstreamIdentityClaims func() map[string]interface{} + // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field. + DefaultLocalFlowState func() map[string]interface{} + // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field. + DefaultBrowserSessionKey string + // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field. + DefaultCompletionCodeHash string +) + +// OrderOption defines the ordering options for the PendingAuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. 
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByIntent orders the results by the intent field. +func ByIntent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntent, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByTargetUserID orders the results by the target_user_id field. +func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTargetUserID, opts...).ToFunc() +} + +// ByRedirectTo orders the results by the redirect_to field. +func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRedirectTo, opts...).ToFunc() +} + +// ByResolvedEmail orders the results by the resolved_email field. +func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc() +} + +// ByRegistrationPasswordHash orders the results by the registration_password_hash field. +func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc() +} + +// ByBrowserSessionKey orders the results by the browser_session_key field. 
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc() +} + +// ByCompletionCodeHash orders the results by the completion_code_hash field. +func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc() +} + +// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field. +func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc() +} + +// ByEmailVerifiedAt orders the results by the email_verified_at field. +func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc() +} + +// ByPasswordVerifiedAt orders the results by the password_verified_at field. +func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc() +} + +// ByTotpVerifiedAt orders the results by the totp_verified_at field. +func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByConsumedAt orders the results by the consumed_at field. +func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsumedAt, opts...).ToFunc() +} + +// ByTargetUserField orders the results by target_user field. +func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAdoptionDecisionField orders the results by adoption_decision field. 
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...)) + } +} +func newTargetUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TargetUserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) +} +func newAdoptionDecisionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) +} diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go new file mode 100644 index 0000000000000000000000000000000000000000..cb316f476e44195e74f961699d058837cbe38630 --- /dev/null +++ b/backend/ent/pendingauthsession/where.go @@ -0,0 +1,1262 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. 
+func IDNotIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ. +func Intent(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. 
+func ProviderType(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ. +func TargetUserID(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ. +func RedirectTo(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ. +func ResolvedEmail(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ. +func RegistrationPasswordHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ. 
+func BrowserSessionKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ. +func CompletionCodeHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ. +func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ. +func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ. +func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ. +func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ. 
+func ConsumedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. 
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. 
+func SessionTokenNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. +func SessionTokenLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. 
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// IntentEQ applies the EQ predicate on the "intent" field. +func IntentEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// IntentNEQ applies the NEQ predicate on the "intent" field. +func IntentNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v)) +} + +// IntentIn applies the In predicate on the "intent" field. +func IntentIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...)) +} + +// IntentNotIn applies the NotIn predicate on the "intent" field. +func IntentNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...)) +} + +// IntentGT applies the GT predicate on the "intent" field. +func IntentGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v)) +} + +// IntentGTE applies the GTE predicate on the "intent" field. 
+func IntentGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v)) +} + +// IntentLT applies the LT predicate on the "intent" field. +func IntentLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v)) +} + +// IntentLTE applies the LTE predicate on the "intent" field. +func IntentLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v)) +} + +// IntentContains applies the Contains predicate on the "intent" field. +func IntentContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v)) +} + +// IntentHasPrefix applies the HasPrefix predicate on the "intent" field. +func IntentHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v)) +} + +// IntentHasSuffix applies the HasSuffix predicate on the "intent" field. +func IntentHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v)) +} + +// IntentEqualFold applies the EqualFold predicate on the "intent" field. +func IntentEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v)) +} + +// IntentContainsFold applies the ContainsFold predicate on the "intent" field. +func IntentContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. 
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. 
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. 
+func ProviderKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. 
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. 
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field. +func TargetUserIDEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field. +func TargetUserIDNEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v)) +} + +// TargetUserIDIn applies the In predicate on the "target_user_id" field. 
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field. +func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field. +func TargetUserIDIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID)) +} + +// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field. +func TargetUserIDNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID)) +} + +// RedirectToEQ applies the EQ predicate on the "redirect_to" field. +func RedirectToEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field. +func RedirectToNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v)) +} + +// RedirectToIn applies the In predicate on the "redirect_to" field. +func RedirectToIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...)) +} + +// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field. +func RedirectToNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...)) +} + +// RedirectToGT applies the GT predicate on the "redirect_to" field. +func RedirectToGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v)) +} + +// RedirectToGTE applies the GTE predicate on the "redirect_to" field. 
+func RedirectToGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v)) +} + +// RedirectToLT applies the LT predicate on the "redirect_to" field. +func RedirectToLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v)) +} + +// RedirectToLTE applies the LTE predicate on the "redirect_to" field. +func RedirectToLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v)) +} + +// RedirectToContains applies the Contains predicate on the "redirect_to" field. +func RedirectToContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v)) +} + +// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field. +func RedirectToHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v)) +} + +// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field. +func RedirectToHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v)) +} + +// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field. +func RedirectToEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v)) +} + +// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field. +func RedirectToContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v)) +} + +// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field. 
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field. +func ResolvedEmailNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailIn applies the In predicate on the "resolved_email" field. +func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field. +func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailGT applies the GT predicate on the "resolved_email" field. +func ResolvedEmailGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v)) +} + +// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field. +func ResolvedEmailGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailLT applies the LT predicate on the "resolved_email" field. +func ResolvedEmailLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v)) +} + +// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field. +func ResolvedEmailLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field. 
+func ResolvedEmailContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field. +func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field. +func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v)) +} + +// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field. +func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v)) +} + +// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field. +func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field. 
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field. 
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field. +func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field. +func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field. +func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field. 
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field. +func BrowserSessionKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field. +func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field. +func BrowserSessionKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field. +func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field. +func BrowserSessionKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field. 
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field. +func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field. +func CompletionCodeHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field. +func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field. +func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field. +func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field. +func CompletionCodeHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field. +func CompletionCodeHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field. 
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field. +func CompletionCodeHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field. +func CompletionCodeHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field. +func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field. +func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field. 
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field. 
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt)) +} + +// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt)) +} + +// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field. +func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field. +func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field. +func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field. +func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field. +func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field. 
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field. +func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field. +func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field. +func EmailVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt)) +} + +// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field. +func EmailVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt)) +} + +// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field. +func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field. 
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field. +func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field. +func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt)) +} + +// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt)) +} + +// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field. 
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field. +func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field. +func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field. +func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt)) +} + +// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field. 
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field. 
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field. +func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v)) +} + +// ConsumedAtIn applies the In predicate on the "consumed_at" field. +func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field. +func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtGT applies the GT predicate on the "consumed_at" field. +func ConsumedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v)) +} + +// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field. +func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v)) +} + +// ConsumedAtLT applies the LT predicate on the "consumed_at" field. +func ConsumedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v)) +} + +// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field. +func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v)) +} + +// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field. +func ConsumedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt)) +} + +// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field. 
+func ConsumedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt)) +} + +// HasTargetUser applies the HasEdge predicate on the "target_user" edge. +func HasTargetUser() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates). +func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newTargetUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge. +func HasAdoptionDecision() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates). +func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newAdoptionDecisionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.NotPredicates(p)) +} diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go new file mode 100644 index 0000000000000000000000000000000000000000..60276daa1bd9a1913f8fb65b0ff515471dc48210 --- /dev/null +++ b/backend/ent/pendingauthsession_create.go @@ -0,0 +1,1815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity. +type PendingAuthSessionCreate struct { + config + mutation *PendingAuthSessionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetIntent sets the "intent" field. +func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate { + _c.mutation.SetIntent(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetTargetUserID sets the "target_user_id" field. +func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate { + _c.mutation.SetTargetUserID(v) + return _c +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate { + if v != nil { + _c.SetTargetUserID(*v) + } + return _c +} + +// SetRedirectTo sets the "redirect_to" field. 
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate { + _c.mutation.SetRedirectTo(v) + return _c +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRedirectTo(*v) + } + return _c +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate { + _c.mutation.SetResolvedEmail(v) + return _c +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetResolvedEmail(*v) + } + return _c +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetRegistrationPasswordHash(v) + return _c +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRegistrationPasswordHash(*v) + } + return _c +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetUpstreamIdentityClaims(v) + return _c +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetLocalFlowState(v) + return _c +} + +// SetBrowserSessionKey sets the "browser_session_key" field. 
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetBrowserSessionKey(v) + return _c +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetBrowserSessionKey(*v) + } + return _c +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeHash(v) + return _c +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeHash(*v) + } + return _c +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeExpiresAt(v) + return _c +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeExpiresAt(*v) + } + return _c +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetEmailVerifiedAt(v) + return _c +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. 
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetEmailVerifiedAt(*v) + } + return _c +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetPasswordVerifiedAt(v) + return _c +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetPasswordVerifiedAt(*v) + } + return _c +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetTotpVerifiedAt(v) + return _c +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetTotpVerifiedAt(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetConsumedAt sets the "consumed_at" field. +func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetConsumedAt(v) + return _c +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetConsumedAt(*v) + } + return _c +} + +// SetTargetUser sets the "target_user" edge to the User entity. 
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate { + return _c.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate { + _c.mutation.SetAdoptionDecisionID(id) + return _c +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate { + if id != nil { + _c = _c.SetAdoptionDecisionID(*id) + } + return _c +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate { + return _c.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation { + return _c.mutation +} + +// Save creates the PendingAuthSession in the database. +func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *PendingAuthSessionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := pendingauthsession.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := pendingauthsession.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.RedirectTo(); !ok { + v := pendingauthsession.DefaultRedirectTo + _c.mutation.SetRedirectTo(v) + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + v := pendingauthsession.DefaultResolvedEmail + _c.mutation.SetResolvedEmail(v) + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + v := pendingauthsession.DefaultRegistrationPasswordHash + _c.mutation.SetRegistrationPasswordHash(v) + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + v := pendingauthsession.DefaultUpstreamIdentityClaims() + _c.mutation.SetUpstreamIdentityClaims(v) + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + v := pendingauthsession.DefaultLocalFlowState() + _c.mutation.SetLocalFlowState(v) + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + v := pendingauthsession.DefaultBrowserSessionKey + _c.mutation.SetBrowserSessionKey(v) + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + v := pendingauthsession.DefaultCompletionCodeHash + _c.mutation.SetCompletionCodeHash(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *PendingAuthSessionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)} + } + if _, ok := _c.mutation.SessionToken(); !ok { + return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)} + } + if v, ok := _c.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if _, ok := _c.mutation.Intent(); !ok { + return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)} + } + if v, ok := _c.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := 
pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.RedirectTo(); !ok { + return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)} + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)} + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)} + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)} + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)} + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)} + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + return &ValidationError{Name: "completion_code_hash", err: 
errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)} + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)} + } + return nil +} + +func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &PendingAuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + _node.SessionToken = value + } + if value, ok := _c.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + _node.Intent = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, 
field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + _node.RedirectTo = value + } + if value, ok := _c.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + _node.ResolvedEmail = value + } + if value, ok := _c.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + _node.RegistrationPasswordHash = value + } + if value, ok := _c.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + _node.UpstreamIdentityClaims = value + } + if value, ok := _c.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + _node.LocalFlowState = value + } + if value, ok := _c.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + _node.BrowserSessionKey = value + } + if value, ok := _c.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + _node.CompletionCodeHash = value + } + if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + _node.CompletionCodeExpiresAt = &value + } + if value, ok := _c.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + _node.EmailVerifiedAt = &value + } + if value, ok := _c.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + 
_node.PasswordVerifiedAt = &value + } + if value, ok := _c.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + _node.TotpVerifiedAt = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := _c.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + _node.ConsumedAt = &value + } + if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TargetUserID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. 
+// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne { + _c.conflict = opts + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +type ( + // PendingAuthSessionUpsertOne is the builder for "upsert"-ing + // one PendingAuthSession node. + PendingAuthSessionUpsertOne struct { + create *PendingAuthSessionCreate + } + + // PendingAuthSessionUpsert is the "OnConflict" setter. + PendingAuthSessionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpdatedAt) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldSessionToken) + return u +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldIntent, v) + return u +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldIntent) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderSubject) + return u +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTargetUserID, v) + return u +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTargetUserID) + return u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTargetUserID) + return u +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRedirectTo, v) + return u +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRedirectTo) + return u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldResolvedEmail, v) + return u +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldResolvedEmail) + return u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. 
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRegistrationPasswordHash, v) + return u +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash) + return u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v) + return u +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims) + return u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldLocalFlowState, v) + return u +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldLocalFlowState) + return u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldBrowserSessionKey, v) + return u +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldBrowserSessionKey) + return u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeHash, v) + return u +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeHash) + return u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v) + return u +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldEmailVerifiedAt, v) + return u +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldPasswordVerifiedAt, v) + return u +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTotpVerifiedAt, v) + return u +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// SetExpiresAt sets the "expires_at" field. 
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldExpiresAt) + return u +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldConsumedAt, v) + return u +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldConsumedAt) + return u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldConsumedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). 
+// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. 
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. 
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. 
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. 
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk. +type PendingAuthSessionCreateBulk struct { + config + err error + builders []*PendingAuthSessionCreate + conflict []sql.ConflictOption +} + +// Save creates the PendingAuthSession entities in the database. 
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*PendingAuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PendingAuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk { + _c.conflict = opts + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing +// a bulk of PendingAuthSession nodes. +type PendingAuthSessionUpsertBulk struct { + create *PendingAuthSessionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. 
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. 
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. 
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. 
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go new file mode 100644 index 0000000000000000000000000000000000000000..ee4fe6051d11812bdced2f78b4cc9554828910fa --- /dev/null +++ b/backend/ent/pendingauthsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity. +type PendingAuthSessionDelete struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity. +type PendingAuthSessionDeleteOne struct { + _d *PendingAuthSessionDelete +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. 
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{pendingauthsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go new file mode 100644 index 0000000000000000000000000000000000000000..78e29cd2bedf07258e955a0ca95420c5e3da9e3e --- /dev/null +++ b/backend/ent/pendingauthsession_query.go @@ -0,0 +1,717 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities. +type PendingAuthSessionQuery struct { + config + ctx *QueryContext + order []pendingauthsession.OrderOption + inters []Interceptor + predicates []predicate.PendingAuthSession + withTargetUser *UserQuery + withAdoptionDecision *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PendingAuthSessionQuery builder. 
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryTargetUser chains the current query on the "target_user" edge. +func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecision chains the current query on the "adoption_decision" edge. 
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first PendingAuthSession entity from the query. +// Returns a *NotFoundError when no PendingAuthSession was found. +func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{pendingauthsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first PendingAuthSession ID from the query. +// Returns a *NotFoundError when no PendingAuthSession ID was found. 
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{pendingauthsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one PendingAuthSession entity is found. +// Returns a *NotFoundError when no PendingAuthSession entities are found. +func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{pendingauthsession.Label} + default: + return nil, &NotSingularError{pendingauthsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only PendingAuthSession ID in the query. +// Returns a *NotSingularError when more than one PendingAuthSession ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{pendingauthsession.Label} + default: + err = &NotSingularError{pendingauthsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of PendingAuthSessions. +func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]() + return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of PendingAuthSession IDs. +func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. 
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery { + if _q == nil { + return nil + } + return &PendingAuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]pendingauthsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.PendingAuthSession{}, _q.predicates...), + withTargetUser: _q.withTargetUser.Clone(), + withAdoptionDecision: _q.withAdoptionDecision.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithTargetUser tells the query-builder to eager-load the nodes that are connected to +// the "target_user" edge. 
The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withTargetUser = query + return _q +} + +// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecision = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// GroupBy(pendingauthsession.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &PendingAuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = pendingauthsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// Select(pendingauthsession.FieldCreatedAt). 
+// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q} + sbuild.label = pendingauthsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations. +func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !pendingauthsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) { + var ( + nodes = []*PendingAuthSession{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withTargetUser != nil, + _q.withAdoptionDecision != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*PendingAuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &PendingAuthSession{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, 
_q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withTargetUser; query != nil { + if err := _q.loadTargetUser(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecision; query != nil { + if err := _q.loadAdoptionDecision(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*PendingAuthSession) + for i := range nodes { + if nodes[i].TargetUserID == nil { + continue + } + fk := *nodes[i].TargetUserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*PendingAuthSession) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + if len(query.ctx.Fields) > 0 { + 
query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PendingAuthSessionID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for i := range fields { + if fields[i] != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withTargetUser != nil { + _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + 
_spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(pendingauthsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = pendingauthsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. 
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities. +type PendingAuthSessionGroupBy struct { + selector + build *PendingAuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) 
+ if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities. +type PendingAuthSessionSelect struct { + *PendingAuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v) +} + +func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) 
+ } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go new file mode 100644 index 0000000000000000000000000000000000000000..00066f699baf7a7fe87a01f4870035ba87c4431e --- /dev/null +++ b/backend/ent/pendingauthsession_update.go @@ -0,0 +1,1178 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities. +type PendingAuthSessionUpdate struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. 
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. 
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. 
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. 
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. 
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. 
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdate) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + 
} + return nil +} + +func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if 
value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: 
&sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// PendingAuthSessionUpdateOne is the builder for 
updating a single PendingAuthSession entity. +type PendingAuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. 
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. 
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. 
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. 
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. 
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. 
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated PendingAuthSession entity. +func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. 
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdateOne) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, 
err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for _, f := range fields { + if !pendingauthsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + 
_spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + 
_spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, 
+ Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &PendingAuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940067ec1a2d96783ca78994bc6cb8d9fae..dc86471e7931b322d445b6cccdf938d805a95043 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,24 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// AuthIdentity is the predicate function for authidentity builders. +type AuthIdentity func(*sql.Selector) + +// AuthIdentityChannel is the predicate function for authidentitychannel builders. +type AuthIdentityChannel func(*sql.Selector) + +// ChannelMonitor is the predicate function for channelmonitor builders. +type ChannelMonitor func(*sql.Selector) + +// ChannelMonitorDailyRollup is the predicate function for channelmonitordailyrollup builders. +type ChannelMonitorDailyRollup func(*sql.Selector) + +// ChannelMonitorHistory is the predicate function for channelmonitorhistory builders. +type ChannelMonitorHistory func(*sql.Selector) + +// ChannelMonitorRequestTemplate is the predicate function for channelmonitorrequesttemplate builders. 
+type ChannelMonitorRequestTemplate func(*sql.Selector) + // ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. type ErrorPassthroughRule func(*sql.Selector) @@ -30,6 +48,9 @@ type Group func(*sql.Selector) // IdempotencyRecord is the predicate function for idempotencyrecord builders. type IdempotencyRecord func(*sql.Selector) +// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders. +type IdentityAdoptionDecision func(*sql.Selector) + // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) @@ -39,6 +60,9 @@ type PaymentOrder func(*sql.Selector) // PaymentProviderInstance is the predicate function for paymentproviderinstance builders. type PaymentProviderInstance func(*sql.Selector) +// PendingAuthSession is the predicate function for pendingauthsession builders. +type PendingAuthSession func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. 
type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fbdd08c785c5a3a81ebafe1d27ddf488f9bae17c..6b344a5582c6cd3a05eaa69a0c3a57eb3f067999 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,12 +10,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -309,6 +317,366 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + authidentityMixin := schema.AuthIdentity{}.Mixin() + authidentityMixinFields0 := authidentityMixin[0].Fields() + _ = authidentityMixinFields0 + authidentityFields := schema.AuthIdentity{}.Fields() + _ = authidentityFields + // authidentityDescCreatedAt is the schema descriptor for created_at field. 
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor() + // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time) + // authidentityDescUpdatedAt is the schema descriptor for updated_at field. + authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor() + // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time) + // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentityDescProviderType is the schema descriptor for provider_type field. + authidentityDescProviderType := authidentityFields[1].Descriptor() + // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentity.ProviderTypeValidator = func() func(string) error { + validators := authidentityDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentityDescProviderKey is the schema descriptor for provider_key field. + authidentityDescProviderKey := authidentityFields[2].Descriptor() + // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error) + // authidentityDescProviderSubject is the schema descriptor for provider_subject field. 
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor() + // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error) + // authidentityDescMetadata is the schema descriptor for metadata field. + authidentityDescMetadata := authidentityFields[6].Descriptor() + // authidentity.DefaultMetadata holds the default value on creation for the metadata field. + authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{}) + authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin() + authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields() + _ = authidentitychannelMixinFields0 + authidentitychannelFields := schema.AuthIdentityChannel{}.Fields() + _ = authidentitychannelFields + // authidentitychannelDescCreatedAt is the schema descriptor for created_at field. + authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor() + // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time) + // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field. + authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor() + // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time) + // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentitychannelDescProviderType is the schema descriptor for provider_type field. 
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor() + // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentitychannel.ProviderTypeValidator = func() func(string) error { + validators := authidentitychannelDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescProviderKey is the schema descriptor for provider_key field. + authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor() + // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error) + // authidentitychannelDescChannel is the schema descriptor for channel field. + authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor() + // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + authidentitychannel.ChannelValidator = func() func(string) error { + validators := authidentitychannelDescChannel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(channel string) error { + for _, fn := range fns { + if err := fn(channel); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field. 
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor() + // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error) + // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field. + authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor() + // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error) + // authidentitychannelDescMetadata is the schema descriptor for metadata field. + authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor() + // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field. + authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{}) + channelmonitorMixin := schema.ChannelMonitor{}.Mixin() + channelmonitorMixinFields0 := channelmonitorMixin[0].Fields() + _ = channelmonitorMixinFields0 + channelmonitorFields := schema.ChannelMonitor{}.Fields() + _ = channelmonitorFields + // channelmonitorDescCreatedAt is the schema descriptor for created_at field. + channelmonitorDescCreatedAt := channelmonitorMixinFields0[0].Descriptor() + // channelmonitor.DefaultCreatedAt holds the default value on creation for the created_at field. + channelmonitor.DefaultCreatedAt = channelmonitorDescCreatedAt.Default.(func() time.Time) + // channelmonitorDescUpdatedAt is the schema descriptor for updated_at field. + channelmonitorDescUpdatedAt := channelmonitorMixinFields0[1].Descriptor() + // channelmonitor.DefaultUpdatedAt holds the default value on creation for the updated_at field. 
+ channelmonitor.DefaultUpdatedAt = channelmonitorDescUpdatedAt.Default.(func() time.Time) + // channelmonitor.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + channelmonitor.UpdateDefaultUpdatedAt = channelmonitorDescUpdatedAt.UpdateDefault.(func() time.Time) + // channelmonitorDescName is the schema descriptor for name field. + channelmonitorDescName := channelmonitorFields[0].Descriptor() + // channelmonitor.NameValidator is a validator for the "name" field. It is called by the builders before save. + channelmonitor.NameValidator = func() func(string) error { + validators := channelmonitorDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // channelmonitorDescEndpoint is the schema descriptor for endpoint field. + channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor() + // channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save. + channelmonitor.EndpointValidator = func() func(string) error { + validators := channelmonitorDescEndpoint.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(endpoint string) error { + for _, fn := range fns { + if err := fn(endpoint); err != nil { + return err + } + } + return nil + } + }() + // channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field. + channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor() + // channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save. 
+ channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error) + // channelmonitorDescPrimaryModel is the schema descriptor for primary_model field. + channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor() + // channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save. + channelmonitor.PrimaryModelValidator = func() func(string) error { + validators := channelmonitorDescPrimaryModel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(primary_model string) error { + for _, fn := range fns { + if err := fn(primary_model); err != nil { + return err + } + } + return nil + } + }() + // channelmonitorDescExtraModels is the schema descriptor for extra_models field. + channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor() + // channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field. + channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string) + // channelmonitorDescGroupName is the schema descriptor for group_name field. + channelmonitorDescGroupName := channelmonitorFields[6].Descriptor() + // channelmonitor.DefaultGroupName holds the default value on creation for the group_name field. + channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string) + // channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save. + channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error) + // channelmonitorDescEnabled is the schema descriptor for enabled field. + channelmonitorDescEnabled := channelmonitorFields[7].Descriptor() + // channelmonitor.DefaultEnabled holds the default value on creation for the enabled field. 
+ channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool) + // channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field. + channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor() + // channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save. + channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error) + // channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field. + channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor() + // channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field. + channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string) + // channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field. + channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor() + // channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field. + channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string) + // channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save. + channelmonitor.BodyOverrideModeValidator = channelmonitorDescBodyOverrideMode.Validators[0].(func(string) error) + channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields() + _ = channelmonitordailyrollupFields + // channelmonitordailyrollupDescModel is the schema descriptor for model field. + channelmonitordailyrollupDescModel := channelmonitordailyrollupFields[1].Descriptor() + // channelmonitordailyrollup.ModelValidator is a validator for the "model" field. It is called by the builders before save. 
+ channelmonitordailyrollup.ModelValidator = func() func(string) error { + validators := channelmonitordailyrollupDescModel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(model string) error { + for _, fn := range fns { + if err := fn(model); err != nil { + return err + } + } + return nil + } + }() + // channelmonitordailyrollupDescTotalChecks is the schema descriptor for total_checks field. + channelmonitordailyrollupDescTotalChecks := channelmonitordailyrollupFields[3].Descriptor() + // channelmonitordailyrollup.DefaultTotalChecks holds the default value on creation for the total_checks field. + channelmonitordailyrollup.DefaultTotalChecks = channelmonitordailyrollupDescTotalChecks.Default.(int) + // channelmonitordailyrollupDescOkCount is the schema descriptor for ok_count field. + channelmonitordailyrollupDescOkCount := channelmonitordailyrollupFields[4].Descriptor() + // channelmonitordailyrollup.DefaultOkCount holds the default value on creation for the ok_count field. + channelmonitordailyrollup.DefaultOkCount = channelmonitordailyrollupDescOkCount.Default.(int) + // channelmonitordailyrollupDescOperationalCount is the schema descriptor for operational_count field. + channelmonitordailyrollupDescOperationalCount := channelmonitordailyrollupFields[5].Descriptor() + // channelmonitordailyrollup.DefaultOperationalCount holds the default value on creation for the operational_count field. + channelmonitordailyrollup.DefaultOperationalCount = channelmonitordailyrollupDescOperationalCount.Default.(int) + // channelmonitordailyrollupDescDegradedCount is the schema descriptor for degraded_count field. + channelmonitordailyrollupDescDegradedCount := channelmonitordailyrollupFields[6].Descriptor() + // channelmonitordailyrollup.DefaultDegradedCount holds the default value on creation for the degraded_count field. 
+ channelmonitordailyrollup.DefaultDegradedCount = channelmonitordailyrollupDescDegradedCount.Default.(int) + // channelmonitordailyrollupDescFailedCount is the schema descriptor for failed_count field. + channelmonitordailyrollupDescFailedCount := channelmonitordailyrollupFields[7].Descriptor() + // channelmonitordailyrollup.DefaultFailedCount holds the default value on creation for the failed_count field. + channelmonitordailyrollup.DefaultFailedCount = channelmonitordailyrollupDescFailedCount.Default.(int) + // channelmonitordailyrollupDescErrorCount is the schema descriptor for error_count field. + channelmonitordailyrollupDescErrorCount := channelmonitordailyrollupFields[8].Descriptor() + // channelmonitordailyrollup.DefaultErrorCount holds the default value on creation for the error_count field. + channelmonitordailyrollup.DefaultErrorCount = channelmonitordailyrollupDescErrorCount.Default.(int) + // channelmonitordailyrollupDescSumLatencyMs is the schema descriptor for sum_latency_ms field. + channelmonitordailyrollupDescSumLatencyMs := channelmonitordailyrollupFields[9].Descriptor() + // channelmonitordailyrollup.DefaultSumLatencyMs holds the default value on creation for the sum_latency_ms field. + channelmonitordailyrollup.DefaultSumLatencyMs = channelmonitordailyrollupDescSumLatencyMs.Default.(int64) + // channelmonitordailyrollupDescCountLatency is the schema descriptor for count_latency field. + channelmonitordailyrollupDescCountLatency := channelmonitordailyrollupFields[10].Descriptor() + // channelmonitordailyrollup.DefaultCountLatency holds the default value on creation for the count_latency field. + channelmonitordailyrollup.DefaultCountLatency = channelmonitordailyrollupDescCountLatency.Default.(int) + // channelmonitordailyrollupDescSumPingLatencyMs is the schema descriptor for sum_ping_latency_ms field. 
+ channelmonitordailyrollupDescSumPingLatencyMs := channelmonitordailyrollupFields[11].Descriptor() + // channelmonitordailyrollup.DefaultSumPingLatencyMs holds the default value on creation for the sum_ping_latency_ms field. + channelmonitordailyrollup.DefaultSumPingLatencyMs = channelmonitordailyrollupDescSumPingLatencyMs.Default.(int64) + // channelmonitordailyrollupDescCountPingLatency is the schema descriptor for count_ping_latency field. + channelmonitordailyrollupDescCountPingLatency := channelmonitordailyrollupFields[12].Descriptor() + // channelmonitordailyrollup.DefaultCountPingLatency holds the default value on creation for the count_ping_latency field. + channelmonitordailyrollup.DefaultCountPingLatency = channelmonitordailyrollupDescCountPingLatency.Default.(int) + // channelmonitordailyrollupDescComputedAt is the schema descriptor for computed_at field. + channelmonitordailyrollupDescComputedAt := channelmonitordailyrollupFields[13].Descriptor() + // channelmonitordailyrollup.DefaultComputedAt holds the default value on creation for the computed_at field. + channelmonitordailyrollup.DefaultComputedAt = channelmonitordailyrollupDescComputedAt.Default.(func() time.Time) + // channelmonitordailyrollup.UpdateDefaultComputedAt holds the default value on update for the computed_at field. + channelmonitordailyrollup.UpdateDefaultComputedAt = channelmonitordailyrollupDescComputedAt.UpdateDefault.(func() time.Time) + channelmonitorhistoryFields := schema.ChannelMonitorHistory{}.Fields() + _ = channelmonitorhistoryFields + // channelmonitorhistoryDescModel is the schema descriptor for model field. + channelmonitorhistoryDescModel := channelmonitorhistoryFields[1].Descriptor() + // channelmonitorhistory.ModelValidator is a validator for the "model" field. It is called by the builders before save. 
+ channelmonitorhistory.ModelValidator = func() func(string) error { + validators := channelmonitorhistoryDescModel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(model string) error { + for _, fn := range fns { + if err := fn(model); err != nil { + return err + } + } + return nil + } + }() + // channelmonitorhistoryDescMessage is the schema descriptor for message field. + channelmonitorhistoryDescMessage := channelmonitorhistoryFields[5].Descriptor() + // channelmonitorhistory.DefaultMessage holds the default value on creation for the message field. + channelmonitorhistory.DefaultMessage = channelmonitorhistoryDescMessage.Default.(string) + // channelmonitorhistory.MessageValidator is a validator for the "message" field. It is called by the builders before save. + channelmonitorhistory.MessageValidator = channelmonitorhistoryDescMessage.Validators[0].(func(string) error) + // channelmonitorhistoryDescCheckedAt is the schema descriptor for checked_at field. + channelmonitorhistoryDescCheckedAt := channelmonitorhistoryFields[6].Descriptor() + // channelmonitorhistory.DefaultCheckedAt holds the default value on creation for the checked_at field. + channelmonitorhistory.DefaultCheckedAt = channelmonitorhistoryDescCheckedAt.Default.(func() time.Time) + channelmonitorrequesttemplateMixin := schema.ChannelMonitorRequestTemplate{}.Mixin() + channelmonitorrequesttemplateMixinFields0 := channelmonitorrequesttemplateMixin[0].Fields() + _ = channelmonitorrequesttemplateMixinFields0 + channelmonitorrequesttemplateFields := schema.ChannelMonitorRequestTemplate{}.Fields() + _ = channelmonitorrequesttemplateFields + // channelmonitorrequesttemplateDescCreatedAt is the schema descriptor for created_at field. 
+ channelmonitorrequesttemplateDescCreatedAt := channelmonitorrequesttemplateMixinFields0[0].Descriptor() + // channelmonitorrequesttemplate.DefaultCreatedAt holds the default value on creation for the created_at field. + channelmonitorrequesttemplate.DefaultCreatedAt = channelmonitorrequesttemplateDescCreatedAt.Default.(func() time.Time) + // channelmonitorrequesttemplateDescUpdatedAt is the schema descriptor for updated_at field. + channelmonitorrequesttemplateDescUpdatedAt := channelmonitorrequesttemplateMixinFields0[1].Descriptor() + // channelmonitorrequesttemplate.DefaultUpdatedAt holds the default value on creation for the updated_at field. + channelmonitorrequesttemplate.DefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.Default.(func() time.Time) + // channelmonitorrequesttemplate.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + channelmonitorrequesttemplate.UpdateDefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.UpdateDefault.(func() time.Time) + // channelmonitorrequesttemplateDescName is the schema descriptor for name field. + channelmonitorrequesttemplateDescName := channelmonitorrequesttemplateFields[0].Descriptor() + // channelmonitorrequesttemplate.NameValidator is a validator for the "name" field. It is called by the builders before save. + channelmonitorrequesttemplate.NameValidator = func() func(string) error { + validators := channelmonitorrequesttemplateDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // channelmonitorrequesttemplateDescDescription is the schema descriptor for description field. 
+ channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor() + // channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field. + channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string) + // channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save. + channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error) + // channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field. + channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor() + // channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field. + channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string) + // channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field. + channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor() + // channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field. + channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string) + // channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save. 
+ channelmonitorrequesttemplate.BodyOverrideModeValidator = channelmonitorrequesttemplateDescBodyOverrideMode.Validators[0].(func(string) error) errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() _ = errorpassthroughruleMixinFields0 @@ -477,6 +845,10 @@ func init() { groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) + // groupDescRpmLimit is the schema descriptor for rpm_limit field. + groupDescRpmLimit := groupFields[27].Descriptor() + // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. + group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 @@ -512,6 +884,33 @@ func init() { idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) + identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin() + identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields() + _ = identityadoptiondecisionMixinFields0 + identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields() + _ = identityadoptiondecisionFields + // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field. 
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor() + // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field. + identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time) + // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field. + identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor() + // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field. + identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time) + // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time) + // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field. + identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor() + // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field. + identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool) + // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field. + identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor() + // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field. + identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool) + // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field. 
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor() + // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field. + identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time) paymentauditlogFields := schema.PaymentAuditLog{}.Fields() _ = paymentauditlogFields // paymentauditlogDescOrderID is the schema descriptor for order_id field. @@ -578,38 +977,42 @@ func init() { paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor() // paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error) + // paymentorderDescProviderKey is the schema descriptor for provider_key field. + paymentorderDescProviderKey := paymentorderFields[19].Descriptor() + // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error) // paymentorderDescStatus is the schema descriptor for status field. - paymentorderDescStatus := paymentorderFields[19].Descriptor() + paymentorderDescStatus := paymentorderFields[21].Descriptor() // paymentorder.DefaultStatus holds the default value on creation for the status field. paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string) // paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save. paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error) // paymentorderDescRefundAmount is the schema descriptor for refund_amount field. 
- paymentorderDescRefundAmount := paymentorderFields[20].Descriptor() + paymentorderDescRefundAmount := paymentorderFields[22].Descriptor() // paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field. paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64) // paymentorderDescForceRefund is the schema descriptor for force_refund field. - paymentorderDescForceRefund := paymentorderFields[23].Descriptor() + paymentorderDescForceRefund := paymentorderFields[25].Descriptor() // paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field. paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool) // paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field. - paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor() + paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor() // paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save. paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error) // paymentorderDescClientIP is the schema descriptor for client_ip field. - paymentorderDescClientIP := paymentorderFields[32].Descriptor() + paymentorderDescClientIP := paymentorderFields[34].Descriptor() // paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save. paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error) // paymentorderDescSrcHost is the schema descriptor for src_host field. - paymentorderDescSrcHost := paymentorderFields[33].Descriptor() + paymentorderDescSrcHost := paymentorderFields[35].Descriptor() // paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save. 
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error) // paymentorderDescCreatedAt is the schema descriptor for created_at field. - paymentorderDescCreatedAt := paymentorderFields[35].Descriptor() + paymentorderDescCreatedAt := paymentorderFields[37].Descriptor() // paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field. paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time) // paymentorderDescUpdatedAt is the schema descriptor for updated_at field. - paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor() + paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor() // paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time) // paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -682,6 +1085,113 @@ func init() { paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time) + pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin() + pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields() + _ = pendingauthsessionMixinFields0 + pendingauthsessionFields := schema.PendingAuthSession{}.Fields() + _ = pendingauthsessionFields + // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field. + pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor() + // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field. 
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time) + // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field. + pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor() + // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field. + pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time) + // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time) + // pendingauthsessionDescSessionToken is the schema descriptor for session_token field. + pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor() + // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + pendingauthsession.SessionTokenValidator = func() func(string) error { + validators := pendingauthsessionDescSessionToken.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(session_token string) error { + for _, fn := range fns { + if err := fn(session_token); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescIntent is the schema descriptor for intent field. + pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor() + // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save. 
+ pendingauthsession.IntentValidator = func() func(string) error { + validators := pendingauthsessionDescIntent.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(intent string) error { + for _, fn := range fns { + if err := fn(intent); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderType is the schema descriptor for provider_type field. + pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor() + // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + pendingauthsession.ProviderTypeValidator = func() func(string) error { + validators := pendingauthsessionDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field. + pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor() + // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error) + // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field. + pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor() + // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. 
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error) + // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field. + pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor() + // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field. + pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string) + // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field. + pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor() + // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field. + pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string) + // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field. + pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor() + // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field. + pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string) + // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field. + pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor() + // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field. + pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{}) + // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field. 
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor() + // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field. + pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{}) + // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field. + pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor() + // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field. + pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string) + // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field. + pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor() + // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field. + pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -1297,22 +1807,32 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSignupSource is the schema descriptor for signup_source field. + userDescSignupSource := userFields[11].Descriptor() + // user.DefaultSignupSource holds the default value on creation for the signup_source field. + user.DefaultSignupSource = userDescSignupSource.Default.(string) + // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. 
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error) // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. - userDescBalanceNotifyEnabled := userFields[11].Descriptor() + userDescBalanceNotifyEnabled := userFields[14].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. - userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + userDescBalanceNotifyThresholdType := userFields[15].Descriptor() // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[17].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) // userDescTotalRecharged is the schema descriptor for total_recharged field. - userDescTotalRecharged := userFields[15].Descriptor() + userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) + // userDescRpmLimit is the schema descriptor for rpm_limit field. + userDescRpmLimit := userFields[19].Descriptor() + // user.DefaultRpmLimit holds the default value on creation for the rpm_limit field. 
+ user.DefaultRpmLimit = userDescRpmLimit.Default.(int) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go new file mode 100644 index 0000000000000000000000000000000000000000..0b1b56ab0fd9de365a3fa7533255b4df1882b413 --- /dev/null +++ b/backend/ent/schema/auth_identity.go @@ -0,0 +1,94 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var authProviderTypes = map[string]struct{}{ + "email": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, +} + +func validateAuthProviderType(value string) error { + if _, ok := authProviderTypes[value]; ok { + return nil + } + return fmt.Errorf("invalid auth provider type %q", value) +} + +// AuthIdentity stores the canonical login identity for an account. +type AuthIdentity struct { + ent.Schema +} + +func (AuthIdentity) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identities"}, + } +} + +func (AuthIdentity) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentity) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.String("issuer"). + Optional(). 
+ Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("auth_identities"). + Field("user_id"). + Required(). + Unique(), + edge.To("channels", AuthIdentityChannel.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), + edge.To("adoption_decisions", IdentityAdoptionDecision.Type), + } +} + +func (AuthIdentity) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "provider_subject").Unique(), + index.Fields("user_id"), + index.Fields("user_id", "provider_type"), + } +} diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go new file mode 100644 index 0000000000000000000000000000000000000000..69f2ad028f3249331ec8e2a3ecdcfbab93e1acaf --- /dev/null +++ b/backend/ent/schema/auth_identity_channel.go @@ -0,0 +1,72 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity. +type AuthIdentityChannel struct { + ent.Schema +} + +func (AuthIdentityChannel) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identity_channels"}, + } +} + +func (AuthIdentityChannel) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentityChannel) Fields() []ent.Field { + return []ent.Field{ + field.Int64("identity_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). 
+ Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel"). + MaxLen(20). + NotEmpty(), + field.String("channel_app_id"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentityChannel) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("identity", AuthIdentity.Type). + Ref("channels"). + Field("identity_id"). + Required(). + Unique(), + } +} + +func (AuthIdentityChannel) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fbb932368a5abaa41e17dfc508166d98fd9217fe --- /dev/null +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -0,0 +1,168 @@ +package schema + +import ( + "testing" + + "entgo.io/ent" + "entgo.io/ent/entc/load" + "entgo.io/ent/schema/field" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationSchemas(t *testing.T) { + spec, err := (&load.Config{Path: "."}).Load() + require.NoError(t, err) + + schemas := map[string]*load.Schema{} + for _, schema := range spec.Schemas { + schemas[schema.Name] = schema + } + + authIdentity := requireSchema(t, schemas, "AuthIdentity") + requireSchemaFields(t, authIdentity, + "user_id", + "provider_type", + "provider_key", + "provider_subject", + "verified_at", + "issuer", + "metadata", + ) + requireHasUniqueIndex(t, authIdentity, "provider_type", 
"provider_key", "provider_subject") + + authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel") + requireSchemaFields(t, authIdentityChannel, + "identity_id", + "provider_type", + "provider_key", + "channel", + "channel_app_id", + "channel_subject", + "metadata", + ) + requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject") + + pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession") + requireSchemaFields(t, pendingAuthSession, + "intent", + "provider_type", + "provider_key", + "provider_subject", + "target_user_id", + "redirect_to", + "resolved_email", + "registration_password_hash", + "upstream_identity_claims", + "local_flow_state", + "browser_session_key", + "completion_code_hash", + "completion_code_expires_at", + "email_verified_at", + "password_verified_at", + "totp_verified_at", + "expires_at", + "consumed_at", + ) + + adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision") + requireSchemaFields(t, adoptionDecision, + "pending_auth_session_id", + "identity_id", + "adopt_display_name", + "adopt_avatar", + "decided_at", + ) + requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id") + + userSchema := requireSchema(t, schemas, "User") + requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") + signupSource := requireSchemaField(t, userSchema, "signup_source") + require.Equal(t, field.TypeString, signupSource.Info.Type) + require.True(t, signupSource.Default) + require.Equal(t, "email", signupSource.DefaultValue) + require.Equal(t, 1, signupSource.Validators) + + validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source") + for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} { + require.NoError(t, validator(value)) + } + require.Error(t, validator("github")) +} + +func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { + t.Helper() 
+ + schema, ok := schemas[name] + require.True(t, ok, "schema %s should exist", name) + return schema +} + +func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { + t.Helper() + + fields := map[string]struct{}{} + for _, field := range schema.Fields { + fields[field.Name] = struct{}{} + } + + for _, name := range names { + _, ok := fields[name] + require.True(t, ok, "schema %s should include field %s", schema.Name, name) + } +} + +func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field { + t.Helper() + + for _, schemaField := range schema.Fields { + if schemaField.Name == name { + return schemaField + } + } + + require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name) + return nil +} + +func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error { + t.Helper() + + for _, entField := range fields { + descriptor := entField.Descriptor() + if descriptor.Name != name { + continue + } + require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name) + validator, ok := descriptor.Validators[0].(func(string) error) + require.True(t, ok, "field %s validator should be func(string) error", name) + return validator + } + + require.Failf(t, "missing field validator", "schema should include field %s", name) + return nil +} + +func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { + t.Helper() + + for _, index := range schema.Indexes { + if !index.Unique { + continue + } + if len(index.Fields) != len(fields) { + continue + } + match := true + for i := range fields { + if index.Fields[i] != fields[i] { + match = false + break + } + } + if match { + return + } + } + + require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields) +} diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go new file mode 100644 index 
0000000000000000000000000000000000000000..355ade4b686383207c2041202bccd673b63c4256 --- /dev/null +++ b/backend/ent/schema/channel_monitor.go @@ -0,0 +1,110 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ChannelMonitor holds the schema definition for the ChannelMonitor entity. +// 渠道监控配置:定期对指定 provider/endpoint/api_key 下的模型做心跳测试。 +type ChannelMonitor struct { + ent.Schema +} + +func (ChannelMonitor) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "channel_monitors"}, + } +} + +func (ChannelMonitor) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (ChannelMonitor) Fields() []ent.Field { + return []ent.Field{ + field.String("name"). + NotEmpty(). + MaxLen(100), + field.Enum("provider"). + Values("openai", "anthropic", "gemini"), + field.String("endpoint"). + NotEmpty(). + MaxLen(500). + Comment("Provider base origin, e.g. https://api.openai.com"), + field.String("api_key_encrypted"). + NotEmpty(). + Sensitive(). + Comment("AES-256-GCM encrypted API key"), + field.String("primary_model"). + NotEmpty(). + MaxLen(200), + field.JSON("extra_models", []string{}). + Default([]string{}). + Comment("Additional model names to test alongside primary_model"), + field.String("group_name"). + Optional(). + Default(""). + MaxLen(100), + field.Bool("enabled"). + Default(true), + field.Int("interval_seconds"). + Range(15, 3600), + field.Time("last_checked_at"). + Optional(). + Nillable(), + field.Int64("created_by"), + + // ---- 自定义请求快照字段(来自模板 / 手动编辑) ---- + + // template_id: 关联的请求模板 ID(仅用于 UI 分组 + 一键应用)。 + // 实际运行时 checker 只读下面 3 个快照字段,**不再回查模板表**。 + // 模板被删除时此字段会被 SET NULL(见 Edges 的 OnDelete 注解)。 + field.Int64("template_id"). + Optional(). 
+ Nillable(), + // extra_headers: 自定义 HTTP 头快照(来自模板 or 用户手填)。 + // 运行时 merge 进 adapter 默认 headers。 + field.JSON("extra_headers", map[string]string{}). + Default(map[string]string{}), + // body_override_mode: 同 ChannelMonitorRequestTemplate.body_override_mode + field.String("body_override_mode"). + Default("off"). + MaxLen(10), + // body_override: 同 ChannelMonitorRequestTemplate.body_override + field.JSON("body_override", map[string]any{}). + Optional(), + } +} + +func (ChannelMonitor) Edges() []ent.Edge { + return []ent.Edge{ + edge.To("history", ChannelMonitorHistory.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), + edge.To("daily_rollups", ChannelMonitorDailyRollup.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), + // 关联请求模板:模板被删除时 template_id 自动置空, + // 监控本身保留(继续用快照字段跑)。 + edge.To("request_template", ChannelMonitorRequestTemplate.Type). + Field("template_id"). + Unique(). + Annotations(entsql.OnDelete(entsql.SetNull)), + } +} + +func (ChannelMonitor) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("enabled", "last_checked_at"), + index.Fields("provider"), + index.Fields("group_name"), + index.Fields("template_id"), + } +} diff --git a/backend/ent/schema/channel_monitor_daily_rollup.go b/backend/ent/schema/channel_monitor_daily_rollup.go new file mode 100644 index 0000000000000000000000000000000000000000..23f032e317e9544b9d653607e547c7b0dc588e75 --- /dev/null +++ b/backend/ent/schema/channel_monitor_daily_rollup.go @@ -0,0 +1,66 @@ +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。 +// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率 +// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。 +// 超过保留期由每日维护任务分批物理删(不用软删除,理由同 
channel_monitor_history)。 +type ChannelMonitorDailyRollup struct { + ent.Schema +} + +func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "channel_monitor_daily_rollups"}, + } +} + +func (ChannelMonitorDailyRollup) Fields() []ent.Field { + return []ent.Field{ + field.Int64("monitor_id"), + field.String("model"). + NotEmpty(). + MaxLen(200), + field.Time("bucket_date"). + SchemaType(map[string]string{dialect.Postgres: "date"}), + field.Int("total_checks").Default(0), + field.Int("ok_count").Default(0), + field.Int("operational_count").Default(0), + field.Int("degraded_count").Default(0), + field.Int("failed_count").Default(0), + field.Int("error_count").Default(0), + field.Int64("sum_latency_ms").Default(0), + field.Int("count_latency").Default(0), + field.Int64("sum_ping_latency_ms").Default(0), + field.Int("count_ping_latency").Default(0), + field.Time("computed_at").Default(time.Now).UpdateDefault(time.Now), + } +} + +func (ChannelMonitorDailyRollup) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("monitor", ChannelMonitor.Type). + Ref("daily_rollups"). + Field("monitor_id"). + Unique(). + Required(), + } +} + +func (ChannelMonitorDailyRollup) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("monitor_id", "model", "bucket_date").Unique(), + index.Fields("bucket_date"), + } +} diff --git a/backend/ent/schema/channel_monitor_history.go b/backend/ent/schema/channel_monitor_history.go new file mode 100644 index 0000000000000000000000000000000000000000..4366e79a672fe7f87f78ef4b85ebeb5fe9b1c3bc --- /dev/null +++ b/backend/ent/schema/channel_monitor_history.go @@ -0,0 +1,66 @@ +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity. 
+// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天由每日维护任务 +// 先聚合到 channel_monitor_daily_rollups,再分批物理删(不用软删除:日志类表无恢复 +// 需求,软删会让行和索引只增不减,徒增磁盘和查询开销)。 +type ChannelMonitorHistory struct { + ent.Schema +} + +func (ChannelMonitorHistory) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "channel_monitor_histories"}, + } +} + +func (ChannelMonitorHistory) Fields() []ent.Field { + return []ent.Field{ + field.Int64("monitor_id"), + field.String("model"). + NotEmpty(). + MaxLen(200), + field.Enum("status"). + Values("operational", "degraded", "failed", "error"), + field.Int("latency_ms"). + Optional(). + Nillable(), + field.Int("ping_latency_ms"). + Optional(). + Nillable(), + field.String("message"). + Optional(). + Default(""). + MaxLen(500), + field.Time("checked_at"). + Default(time.Now), + } +} + +func (ChannelMonitorHistory) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("monitor", ChannelMonitor.Type). + Ref("history"). + Field("monitor_id"). + Unique(). 
+ Required(), + } +} + +func (ChannelMonitorHistory) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("monitor_id", "model", "checked_at"), + index.Fields("checked_at"), + } +} diff --git a/backend/ent/schema/channel_monitor_request_template.go b/backend/ent/schema/channel_monitor_request_template.go new file mode 100644 index 0000000000000000000000000000000000000000..59df2f29d0ecba627b512b8ded1fb488a837d133 --- /dev/null +++ b/backend/ent/schema/channel_monitor_request_template.go @@ -0,0 +1,80 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ChannelMonitorRequestTemplate 请求模板:一组可复用的 headers + 可选 body 覆盖配置。 +// +// 语义为快照:模板被"应用"到监控时,extra_headers / body_override_mode / body_override +// 会被**拷贝**到 channel_monitors 同名字段;后续模板变动不会自动影响已应用的监控—— +// 必须用户主动在模板编辑 Dialog 里点「应用到关联监控」才会覆盖快照。 +// 这样模板改错不会瞬间打挂所有已经跑起来的监控。 +type ChannelMonitorRequestTemplate struct { + ent.Schema +} + +func (ChannelMonitorRequestTemplate) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "channel_monitor_request_templates"}, + } +} + +func (ChannelMonitorRequestTemplate) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (ChannelMonitorRequestTemplate) Fields() []ent.Field { + return []ent.Field{ + field.String("name"). + NotEmpty(). + MaxLen(100), + field.Enum("provider"). + Values("openai", "anthropic", "gemini"), + field.String("description"). + Optional(). + Default(""). + MaxLen(500), + // extra_headers: 用户自定义 HTTP 头(如 User-Agent 伪装)。 + // 运行时 merge 进 adapter 默认 headers,用户值优先; + // hop-by-hop 黑名单(Host/Content-Length/...)由 checker 过滤。 + field.JSON("extra_headers", map[string]string{}). 
+ Default(map[string]string{}), + // body_override_mode: 'off' | 'merge' | 'replace' + // off - 用 adapter 默认 body(忽略 body_override) + // merge - adapter 默认 body 与 body_override 浅合并(body_override 优先, + // model/messages/contents 等关键字段在 checker 里走黑名单跳过) + // replace - 直接用 body_override 作为完整 body;此时跳过 challenge 校验, + // 改为 HTTP 2xx + 响应文本非空即视为可用 + field.String("body_override_mode"). + Default("off"). + MaxLen(10), + // body_override: JSON 对象,根据 body_override_mode 使用。 + // 用 map[string]any 以便前端传任意结构(含嵌套)。 + field.JSON("body_override", map[string]any{}). + Optional(), + } +} + +func (ChannelMonitorRequestTemplate) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("monitors", ChannelMonitor.Type). + Ref("request_template"), + } +} + +func (ChannelMonitorRequestTemplate) Indexes() []ent.Index { + return []ent.Index{ + // 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。 + index.Fields("provider", "name").Unique(), + } +} diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d78a68986dea728ef984ecd709c84e7636722d3f..11f38d66f0403ee382a28c4a88aaea92aabbc79a 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field { Default(domain.OpenAIMessagesDispatchModelConfig{}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), + + // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。 + field.Int("rpm_limit"). + Default(0). 
+ Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"), } } diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go new file mode 100644 index 0000000000000000000000000000000000000000..9fdd26fbca5120d4ddc17caed5cdfe82cfcedf27 --- /dev/null +++ b/backend/ent/schema/identity_adoption_decision.go @@ -0,0 +1,70 @@ +package schema + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow. +type IdentityAdoptionDecision struct { + ent.Schema +} + +func (IdentityAdoptionDecision) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "identity_adoption_decisions"}, + } +} + +func (IdentityAdoptionDecision) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdentityAdoptionDecision) Fields() []ent.Field { + return []ent.Field{ + field.Int64("pending_auth_session_id"), + field.Int64("identity_id"). + Optional(). + Nillable(), + field.Bool("adopt_display_name"). + Default(false), + field.Bool("adopt_avatar"). + Default(false), + field.Time("decided_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (IdentityAdoptionDecision) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("pending_auth_session", PendingAuthSession.Type). + Ref("adoption_decision"). + Field("pending_auth_session_id"). + Required(). + Unique(), + edge.From("identity", AuthIdentity.Type). + Ref("adoption_decisions"). + Field("identity_id"). 
+ Unique(), + } +} + +func (IdentityAdoptionDecision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("pending_auth_session_id").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go index a9576d2ab02aca442249de0457ea4321cf15f85e..d25d1e5e17c3f85676d0ff25ab3e1096a9896263 100644 --- a/backend/ent/schema/payment_order.go +++ b/backend/ent/schema/payment_order.go @@ -91,6 +91,13 @@ func (PaymentOrder) Fields() []ent.Field { Optional(). Nillable(). MaxLen(64), + field.String("provider_key"). + Optional(). + Nillable(). + MaxLen(30), + field.JSON("provider_snapshot", map[string]any{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), // 状态 field.String("status"). @@ -178,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge { func (PaymentOrder) Indexes() []ent.Index { return []ent.Index{ - index.Fields("out_trade_no"), + index.Fields("out_trade_no"). + Unique(). + Annotations(entsql.IndexWhere("out_trade_no <> ''")), index.Fields("user_id"), index.Fields("status"), index.Fields("expires_at"), diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go new file mode 100644 index 0000000000000000000000000000000000000000..7e95f08512adfedd204a8fa23598583d85e4d5a5 --- /dev/null +++ b/backend/ent/schema/pending_auth_session.go @@ -0,0 +1,135 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var pendingAuthIntents = map[string]struct{}{ + "login": {}, + "bind_current_user": {}, + "adopt_existing_user_by_email": {}, +} + +func validatePendingAuthIntent(value string) error { + if _, ok := pendingAuthIntents[value]; ok { + return nil + } + return fmt.Errorf("invalid pending auth intent %q", 
value) +} + +// PendingAuthSession stores a short-lived post-auth decision session. +type PendingAuthSession struct { + ent.Schema +} + +func (PendingAuthSession) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "pending_auth_sessions"}, + } +} + +func (PendingAuthSession) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (PendingAuthSession) Fields() []ent.Field { + return []ent.Field{ + field.String("session_token"). + MaxLen(255). + NotEmpty(), + field.String("intent"). + MaxLen(40). + NotEmpty(). + Validate(validatePendingAuthIntent), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("target_user_id"). + Optional(). + Nillable(), + field.String("redirect_to"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("resolved_email"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("registration_password_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("upstream_identity_claims", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.JSON("local_flow_state", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.String("browser_session_key"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("completion_code_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("completion_code_expires_at"). + Optional(). + Nillable(). 
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("email_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("password_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("totp_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("consumed_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (PendingAuthSession) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("target_user", User.Type). + Ref("pending_auth_sessions"). + Field("target_user_id"). + Unique(), + edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Annotations(entsql.OnDelete(entsql.Cascade)). + Unique(), + } +} + +func (PendingAuthSession) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("session_token").Unique(), + index.Fields("target_user_id"), + index.Fields("expires_at"), + index.Fields("provider_type", "provider_key", "provider_subject"), + index.Fields("completion_code_hash"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ef52e985d91d84e568518cbecb1684251d82652c..83da5c32ba8f0feef20fae242be40a98aa7c1f27 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -1,6 +1,8 @@ package schema import ( + "fmt" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -72,6 +74,24 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + field.String("signup_source"). 
+ Validate(func(value string) error { + switch value { + case "email", "linuxdo", "wechat", "oidc": + return nil + default: + return fmt.Errorf("must be one of email, linuxdo, wechat, oidc") + } + }). + Default("email"), + field.Time("last_login_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("last_active_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), // 余额不足通知 field.Bool("balance_notify_enabled"). @@ -88,6 +108,10 @@ func (User) Fields() []ent.Field { field.Float("total_recharged"). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). Default(0), + + // 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。 + field.Int("rpm_limit"). + Default(0), } } @@ -104,6 +128,9 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), + edge.To("auth_identities", AuthIdentity.Type). + Annotations(entsql.OnDelete(entsql.Cascade)), + edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index bb3139d5c84119195c82c071194ca4b8663fbab0..611028e9157dbd470cf8c29de272823dcdfc045e 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,18 +24,34 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient + // ChannelMonitor is the client for interacting with the ChannelMonitor builders. 
+ ChannelMonitor *ChannelMonitorClient + // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders. + ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient + // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders. + ChannelMonitorHistory *ChannelMonitorHistoryClient + // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders. + ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. 
@@ -202,12 +218,20 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.AuthIdentity = NewAuthIdentityClient(tx.config) + tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config) + tx.ChannelMonitor = NewChannelMonitorClient(tx.config) + tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config) + tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config) + tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) + tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) + tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 9fa91f74b9dda8b6bbeedc5f1f6fdf931df018d3..06670444897fb23b8fd3bc329aa27efaddfed386 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,12 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SignupSource holds the value of the "signup_source" field. + SignupSource string `json:"signup_source,omitempty"` + // LastLoginAt holds the value of the "last_login_at" field. + LastLoginAt *time.Time `json:"last_login_at,omitempty"` + // LastActiveAt holds the value of the "last_active_at" field. 
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. @@ -55,6 +61,8 @@ type User struct { BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"` // TotalRecharged holds the value of the "total_recharged" field. TotalRecharged float64 `json:"total_recharged,omitempty"` + // RpmLimit holds the value of the "rpm_limit" field. + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -83,11 +91,15 @@ type UserEdges struct { PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` // PaymentOrders holds the value of the payment_orders edge. PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` + // AuthIdentities holds the value of the auth_identities edge. + AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` + // PendingAuthSessions holds the value of the pending_auth_sessions edge. + PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [11]bool + loadedTypes [13]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -180,10 +192,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { return nil, &NotLoadedError{edge: "payment_orders"} } +// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge +// was not loaded in eager-loading. 
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) { + if e.loadedTypes[10] { + return e.AuthIdentities, nil + } + return nil, &NotLoadedError{edge: "auth_identities"} +} + +// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { + if e.loadedTypes[11] { + return e.PendingAuthSessions, nil + } + return nil, &NotLoadedError{edge: "pending_auth_sessions"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[10] { + if e.loadedTypes[12] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -198,11 +228,11 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt: values[i] = 
new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -312,6 +342,26 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSignupSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field signup_source", values[i]) + } else if value.Valid { + _m.SignupSource = value.String + } + case user.FieldLastLoginAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login_at", values[i]) + } else if value.Valid { + _m.LastLoginAt = new(time.Time) + *_m.LastLoginAt = value.Time + } + case user.FieldLastActiveAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_active_at", values[i]) + } else if value.Valid { + _m.LastActiveAt = new(time.Time) + *_m.LastActiveAt = value.Time + } case user.FieldBalanceNotifyEnabled: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) @@ -343,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.TotalRecharged = value.Float64 } + case user.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -406,6 +462,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { return NewUserClient(_m.config).QueryPaymentOrders(_m) } +// QueryAuthIdentities queries the "auth_identities" edge of the User entity. +func (_m *User) QueryAuthIdentities() *AuthIdentityQuery { + return NewUserClient(_m.config).QueryAuthIdentities(_m) +} + +// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity. 
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { + return NewUserClient(_m.config).QueryPendingAuthSessions(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) @@ -482,6 +548,19 @@ func (_m *User) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + builder.WriteString("signup_source=") + builder.WriteString(_m.SignupSource) + builder.WriteString(", ") + if v := _m.LastLoginAt; v != nil { + builder.WriteString("last_login_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.LastActiveAt; v != nil { + builder.WriteString("last_active_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") @@ -498,6 +577,9 @@ func (_m *User) String() string { builder.WriteString(", ") builder.WriteString("total_recharged=") builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index d88a3a380b165f7a8988710ff761b2895aedbec0..e11a8a32e420da63623f1b3ff96dc25b2a5204c8 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,12 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSignupSource holds the string denoting the signup_source field in the database. 
+ FieldSignupSource = "signup_source" + // FieldLastLoginAt holds the string denoting the last_login_at field in the database. + FieldLastLoginAt = "last_login_at" + // FieldLastActiveAt holds the string denoting the last_active_at field in the database. + FieldLastActiveAt = "last_active_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. @@ -53,6 +59,8 @@ const ( FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails" // FieldTotalRecharged holds the string denoting the total_recharged field in the database. FieldTotalRecharged = "total_recharged" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. + FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -73,6 +81,10 @@ const ( EdgePromoCodeUsages = "promo_code_usages" // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. EdgePaymentOrders = "payment_orders" + // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations. + EdgeAuthIdentities = "auth_identities" + // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. + EdgePendingAuthSessions = "pending_auth_sessions" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -145,6 +157,20 @@ const ( PaymentOrdersInverseTable = "payment_orders" // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. 
PaymentOrdersColumn = "user_id" + // AuthIdentitiesTable is the table that holds the auth_identities relation/edge. + AuthIdentitiesTable = "auth_identities" + // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + AuthIdentitiesInverseTable = "auth_identities" + // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge. + AuthIdentitiesColumn = "user_id" + // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge. + PendingAuthSessionsTable = "pending_auth_sessions" + // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionsInverseTable = "pending_auth_sessions" + // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. + PendingAuthSessionsColumn = "target_user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -171,11 +197,15 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSignupSource, + FieldLastLoginAt, + FieldLastActiveAt, FieldBalanceNotifyEnabled, FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, FieldBalanceNotifyExtraEmails, FieldTotalRecharged, + FieldRpmLimit, } var ( @@ -232,6 +262,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSignupSource holds the default value on creation for the "signup_source" field. + DefaultSignupSource string + // SignupSourceValidator is a validator for the "signup_source" field. 
It is called by the builders before save. + SignupSourceValidator func(string) error // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. @@ -240,6 +274,8 @@ var ( DefaultBalanceNotifyExtraEmails string // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field. DefaultTotalRecharged float64 + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the User queries. @@ -320,6 +356,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySignupSource orders the results by the signup_source field. +func BySignupSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignupSource, opts...).ToFunc() +} + +// ByLastLoginAt orders the results by the last_login_at field. +func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc() +} + +// ByLastActiveAt orders the results by the last_active_at field. +func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc() +} + // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() @@ -345,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. 
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -485,6 +541,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByAuthIdentitiesCount orders the results by auth_identities count. +func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...) + } +} + +// ByAuthIdentities orders the results by auth_identities terms. +func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count. +func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...) + } +} + +// ByPendingAuthSessions orders the results by pending_auth_sessions terms. +func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. 
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -568,6 +652,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), ) } +func newAuthIdentitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AuthIdentitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) +} +func newPendingAuthSessionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 2788aa7adc4f41e98c35e1d3810673632c0effcc..05d3b35b9628706b49762ca8e2af89bbef403bc2 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ. +func SignupSource(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ. +func LastLoginAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ. +func LastActiveAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. 
It's identical to BalanceNotifyEnabledEQ. func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -150,6 +165,11 @@ func TotalRecharged(v float64) predicate.User { return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. +func RpmLimit(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -885,6 +905,171 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SignupSourceEQ applies the EQ predicate on the "signup_source" field. +func SignupSourceEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field. +func SignupSourceNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSignupSource, v)) +} + +// SignupSourceIn applies the In predicate on the "signup_source" field. +func SignupSourceIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldSignupSource, vs...)) +} + +// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field. +func SignupSourceNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...)) +} + +// SignupSourceGT applies the GT predicate on the "signup_source" field. +func SignupSourceGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldSignupSource, v)) +} + +// SignupSourceGTE applies the GTE predicate on the "signup_source" field. 
+func SignupSourceGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldSignupSource, v)) +} + +// SignupSourceLT applies the LT predicate on the "signup_source" field. +func SignupSourceLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldSignupSource, v)) +} + +// SignupSourceLTE applies the LTE predicate on the "signup_source" field. +func SignupSourceLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldSignupSource, v)) +} + +// SignupSourceContains applies the Contains predicate on the "signup_source" field. +func SignupSourceContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldSignupSource, v)) +} + +// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field. +func SignupSourceHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v)) +} + +// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field. +func SignupSourceHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v)) +} + +// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field. +func SignupSourceEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldSignupSource, v)) +} + +// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field. +func SignupSourceContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldSignupSource, v)) +} + +// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field. +func LastLoginAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field. 
+func LastLoginAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtIn applies the In predicate on the "last_login_at" field. +func LastLoginAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field. +func LastLoginAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtGT applies the GT predicate on the "last_login_at" field. +func LastLoginAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastLoginAt, v)) +} + +// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field. +func LastLoginAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastLoginAt, v)) +} + +// LastLoginAtLT applies the LT predicate on the "last_login_at" field. +func LastLoginAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastLoginAt, v)) +} + +// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field. +func LastLoginAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastLoginAt, v)) +} + +// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field. +func LastLoginAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastLoginAt)) +} + +// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field. +func LastLoginAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastLoginAt)) +} + +// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field. +func LastActiveAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field. 
+func LastActiveAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtIn applies the In predicate on the "last_active_at" field. +func LastActiveAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field. +func LastActiveAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtGT applies the GT predicate on the "last_active_at" field. +func LastActiveAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastActiveAt, v)) +} + +// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field. +func LastActiveAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastActiveAt, v)) +} + +// LastActiveAtLT applies the LT predicate on the "last_active_at" field. +func LastActiveAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastActiveAt, v)) +} + +// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field. +func LastActiveAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastActiveAt, v)) +} + +// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field. +func LastActiveAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastActiveAt)) +} + +// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field. +func LastActiveAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastActiveAt)) +} + // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. 
func BalanceNotifyEnabledEQ(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -1115,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User { return predicate.User(sql.FieldLTE(FieldTotalRecharged, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.User { + return predicate.User(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.User { + return predicate.User(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.User { + return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. +func RpmLimitGT(v int) predicate.User { + return predicate.User(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. +func RpmLimitGTE(v int) predicate.User { + return predicate.User(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.User { + return predicate.User(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.User { + return predicate.User(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { @@ -1345,6 +1570,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { }) } +// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge. 
+func HasAuthIdentities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates). +func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newAuthIdentitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge. +func HasPendingAuthSessions() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates). +func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPendingAuthSessionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. 
func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index fbc64f9c46d852017276c26af0071460689b6058..b4161128fd66baee8bf9c6d186f87ba42fe90c14 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSignupSource sets the "signup_source" field. +func (_c *UserCreate) SetSignupSource(v string) *UserCreate { + _c.mutation.SetSignupSource(v) + return _c +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate { + if v != nil { + _c.SetSignupSource(*v) + } + return _c +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate { + _c.mutation.SetLastLoginAt(v) + return _c +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastLoginAt(*v) + } + return _c +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate { + _c.mutation.SetLastActiveAt(v) + return _c +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. 
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastActiveAt(*v) + } + return _c +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { _c.mutation.SetBalanceNotifyEnabled(v) @@ -281,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate { return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *UserCreate) SetRpmLimit(v int) *UserCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -431,6 +489,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { return _c.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate { + _c.mutation.AddAuthIdentityIDs(ids...) + return _c +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate { + _c.mutation.AddPendingAuthSessionIDs(ids...) + return _c +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. 
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -510,6 +598,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SignupSource(); !ok { + v := user.DefaultSignupSource + _c.mutation.SetSignupSource(v) + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) @@ -526,6 +618,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotalRecharged _c.mutation.SetTotalRecharged(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := user.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -589,6 +685,14 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SignupSource(); !ok { + return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)} + } + if v, ok := _c.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } @@ -601,6 +705,9 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotalRecharged(); !ok { return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field 
"User.total_recharged"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)} + } return nil } @@ -684,6 +791,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + _node.SignupSource = value + } + if value, ok := _c.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + _node.LastLoginAt = &value + } + if value, ok := _c.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + _node.LastActiveAt = &value + } if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value @@ -704,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) _node.TotalRecharged = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -868,6 +991,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = 
append(_spec.Edges, edge) + } + if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -1106,6 +1261,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsert) SetSignupSource(v string) *UserUpsert { + u.Set(user.FieldSignupSource, v) + return u +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsert) UpdateSignupSource() *UserUpsert { + u.SetExcluded(user.FieldSignupSource) + return u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastLoginAt, v) + return u +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert { + u.SetExcluded(user.FieldLastLoginAt) + return u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsert) ClearLastLoginAt() *UserUpsert { + u.SetNull(user.FieldLastLoginAt) + return u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastActiveAt, v) + return u +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. 
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert { + u.SetExcluded(user.FieldLastActiveAt) + return u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsert) ClearLastActiveAt() *UserUpsert { + u.SetNull(user.FieldLastActiveAt) + return u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { u.Set(user.FieldBalanceNotifyEnabled, v) @@ -1184,6 +1387,24 @@ func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert { + u.Set(user.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsert) UpdateRpmLimit() *UserUpsert { + u.SetExcluded(user.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert { + u.Add(user.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1446,6 +1667,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. 
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1537,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. 
func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1965,6 +2263,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. 
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { @@ -2056,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 113d87aca24a273a137326557955a6b4eb60b751..f1ee5cfe0aad9fb821eaab3d810683b5f571c6e2 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -15,8 +15,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -44,6 +46,8 @@ type UserQuery struct { withAttributeValues *UserAttributeValueQuery withPromoCodeUsages *PromoCodeUsageQuery withPaymentOrders *PaymentOrderQuery + withAuthIdentities *AuthIdentityQuery + withPendingAuthSessions *PendingAuthSessionQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). 
@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { return query } +// QueryAuthIdentities chains the current query on the "auth_identities" edge. +func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge. +func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. 
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { withAttributeValues: _q.withAttributeValues.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(), + withAuthIdentities: _q.withAuthIdentities.Clone(), + withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu return _q } +// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to +// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAuthIdentities = query + return _q +} + +// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSessions = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. 
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [11]bool{ + loadedTypes = [13]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withAttributeValues != nil, _q.withPromoCodeUsages != nil, _q.withPaymentOrders != nil, + _q.withAuthIdentities != nil, + _q.withPendingAuthSessions != nil, _q.withUserAllowedGroups != nil, } ) @@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withAuthIdentities; query != nil { + if err := _q.loadAuthIdentities(ctx, query, nodes, + func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} }, + func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil { + return nil, err + } + } + if query := _q.withPendingAuthSessions; query != nil { + if err := _q.loadPendingAuthSessions(ctx, query, nodes, + func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} }, + func(n *User, e *PendingAuthSession) { + n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e) + }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ } return nil } +func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := 
range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentity.FieldUserID) + } + query.Where(predicate.AuthIdentity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID) + } + query.Where(predicate.PendingAuthSession(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TargetUserID + if fk == nil { + return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go 
b/backend/ent/user_update.go index 6b3552476515700564fe9d17395d548fa99628d3..f1d759ce44025bd4084de3c17b827e19001f99ff 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. 
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { _u.mutation.SetBalanceNotifyEnabled(v) @@ -333,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate { return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -483,6 +560,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) 
+} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -698,6 +805,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. 
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -767,6 +916,11 @@ func (_u *UserUpdate) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -836,6 +990,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -860,6 
+1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1322,6 +1497,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = 
append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1548,6 +1813,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSignupSource sets the "signup_source" field. 
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { _u.mutation.SetBalanceNotifyEnabled(v) @@ -1638,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne { return _u } +// SetRpmLimit sets the "rpm_limit" field. 
+func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1788,6 +2128,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. 
func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2003,6 +2373,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) 
@@ -2085,6 +2497,11 @@ func (_u *UserUpdateOne) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -2171,6 +2588,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -2195,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -2657,6 +3095,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = 
append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, 
+ Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/go.mod b/backend/go.mod index 66b6cc25b598efeb750d0d5b36153f8c95bd6863..982bf91b916b4d8d85a48ba6a99c89e6ca29e1c5 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -39,10 +39,11 @@ require ( github.com/wechatpay-apiv3/wechatpay-go v0.2.21 github.com/zeromicro/go-zero v1.9.4 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.48.0 - golang.org/x/net v0.49.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.40.0 + golang.org/x/crypto v0.49.0 + golang.org/x/image v0.39.0 + golang.org/x/net v0.52.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.41.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -172,10 +173,10 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.32.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect + 
golang.org/x/mod v0.34.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.36.0 // indirect + golang.org/x/tools v0.43.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 9312af63e5ef962c2b48ac55fa7c4daac8309dc1..0f366ee10797bed1d91f2035f65d9c8b37402b17 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -413,16 +413,18 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww= +golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA= 
+golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -432,16 +434,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= 
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e588f44e0ee4be2b249e09167621bb07bbf..87263db09e670fc6e1ed7884f9508245926580c5 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -52,6 +52,11 @@ const ( ConnectionPoolIsolationAccountProxy = "account_proxy" ) +// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。 +// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。 +// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 +const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024 + type Config struct { Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` @@ -65,6 +70,7 @@ type Config struct { JWT JWTConfig `mapstructure:"jwt"` Totp TotpConfig `mapstructure:"totp"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + WeChat WeChatConnectConfig `mapstructure:"wechat_connect"` OIDC OIDCConnectConfig `mapstructure:"oidc_connect"` Default DefaultConfig `mapstructure:"default"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` @@ -185,26 
+191,47 @@ type LinuxDoConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +type WeChatConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + AppID string `mapstructure:"app_id"` + AppSecret string `mapstructure:"app_secret"` + OpenAppID string `mapstructure:"open_app_id"` + OpenAppSecret string `mapstructure:"open_app_secret"` + MPAppID string `mapstructure:"mp_app_id"` + MPAppSecret string `mapstructure:"mp_app_secret"` + MobileAppID string `mapstructure:"mobile_app_id"` + MobileAppSecret string `mapstructure:"mobile_app_secret"` + OpenEnabled bool `mapstructure:"open_enabled"` + MPEnabled bool `mapstructure:"mp_enabled"` + MobileEnabled bool `mapstructure:"mobile_enabled"` + Mode string `mapstructure:"mode"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` +} + type OIDCConnectConfig struct { - Enabled bool `mapstructure:"enabled"` - ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - IssuerURL string `mapstructure:"issuer_url"` - DiscoveryURL string `mapstructure:"discovery_url"` - AuthorizeURL string `mapstructure:"authorize_url"` - TokenURL string `mapstructure:"token_url"` - UserInfoURL string `mapstructure:"userinfo_url"` - JWKSURL string `mapstructure:"jwks_url"` - Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" - RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) - FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) - TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none - UsePKCE bool `mapstructure:"use_pkce"` - ValidateIDToken bool `mapstructure:"validate_id_token"` - AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` 
// 默认 "RS256,ES256,PS256" - ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 - RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false + Enabled bool `mapstructure:"enabled"` + ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + IssuerURL string `mapstructure:"issuer_url"` + DiscoveryURL string `mapstructure:"discovery_url"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + JWKSURL string `mapstructure:"jwks_url"` + Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + ValidateIDToken bool `mapstructure:"validate_id_token"` + UsePKCEExplicit bool `mapstructure:"-" yaml:"-"` + ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"` + AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" + ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 + RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 // 为空时,服务端会尝试一组常见字段名。 @@ -213,6 +240,225 @@ type OIDCConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +const ( + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback" +) + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed 
+ } + } + return "" +} + +func normalizeWeChatConnectMode(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + case "mobile": + return "mobile" + default: + return defaultWeChatConnectMode + } +} + +func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + mode = normalizeWeChatConnectMode(mode) + switch mode { + case "open": + if openEnabled { + return "open" + } + case "mp": + if mpEnabled { + return "mp" + } + case "mobile": + if mobileEnabled { + return "mobile" + } + } + switch { + case openEnabled: + return "open" + case mpEnabled: + return "mp" + case mobileEnabled: + return "mobile" + default: + return mode + } +} + +func defaultWeChatConnectScopesForMode(mode string) string { + switch normalizeWeChatConnectMode(mode) { + case "mp": + return "snsapi_userinfo" + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func normalizeWeChatConnectScopes(raw, mode string) string { + switch normalizeWeChatConnectMode(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopesForMode(mode) + } + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool { + if viper.InConfig(configKey) { + return false + } + _, hasNewEnv := os.LookupEnv(envKey) + return !hasNewEnv +} + +func hasExplicitConfigOrEnv(configKey, envKey string) bool { + if viper.InConfig(configKey) { + return true + } + _, ok := os.LookupEnv(envKey) + return ok +} + +func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) { + if cfg == nil { + return + } + + legacyOpenAppID := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_id", 
"WECHAT_CONNECT_APP_ID") { + legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) + if legacyOpenAppID != "" { + cfg.OpenAppID = legacyOpenAppID + } + } + + legacyOpenAppSecret := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") { + legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) + if legacyOpenAppSecret != "" { + cfg.OpenAppSecret = legacyOpenAppSecret + } + } + + legacyMPAppID := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") { + legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) + if legacyMPAppID != "" { + cfg.MPAppID = legacyMPAppID + } + } + + legacyMPAppSecret := "" + if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") && + shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") { + legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) + if legacyMPAppSecret != "" { + cfg.MPAppSecret = legacyMPAppSecret + } + } + + if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") { + if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" { + cfg.FrontendRedirectURL = legacyFrontend + } + } + + hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != "" + hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != "" + + if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) { + cfg.Enabled = true + } + if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen { + cfg.OpenEnabled = true + } + if 
shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP { + cfg.MPEnabled = true + } + if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") { + switch { + case hasLegacyMP && !hasLegacyOpen: + cfg.Mode = "mp" + case hasLegacyOpen: + cfg.Mode = "open" + } + } + if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") { + switch { + case hasLegacyMP && !hasLegacyOpen: + cfg.Scopes = defaultWeChatConnectScopesForMode("mp") + case hasLegacyOpen: + cfg.Scopes = defaultWeChatConnectScopesForMode("open") + } + } +} + +func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) { + if cfg == nil { + return + } + + cfg.AppID = strings.TrimSpace(cfg.AppID) + cfg.AppSecret = strings.TrimSpace(cfg.AppSecret) + cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID) + cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret) + cfg.MPAppID = strings.TrimSpace(cfg.MPAppID) + cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret) + cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID) + cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret) + cfg.Mode = normalizeWeChatConnectMode(cfg.Mode) + cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL) + cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL) + + cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID) + cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret) + cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID) + cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret) + cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID) + cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret) + cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID) + cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret) + + if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && 
cfg.Enabled { + switch cfg.Mode { + case "mp": + cfg.MPEnabled = true + case "mobile": + cfg.MobileEnabled = true + default: + cfg.OpenEnabled = true + } + } + cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode) + cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode) + if cfg.FrontendRedirectURL == "" { + cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect + } +} + // TokenRefreshConfig OAuth token自动刷新配置 type TokenRefreshConfig struct { // 是否启用自动刷新 @@ -1007,6 +1253,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) + applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat) + normalizeWeChatConnectConfig(&cfg.WeChat) cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName) cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID) cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret) @@ -1024,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath) cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath) cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath) + cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE") + cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN") cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) @@ -1202,6 +1452,24 @@ func setDefaults() { 
viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // WeChat Connect OAuth 登录 + viper.SetDefault("wechat_connect.enabled", false) + viper.SetDefault("wechat_connect.app_id", "") + viper.SetDefault("wechat_connect.app_secret", "") + viper.SetDefault("wechat_connect.open_app_id", "") + viper.SetDefault("wechat_connect.open_app_secret", "") + viper.SetDefault("wechat_connect.mp_app_id", "") + viper.SetDefault("wechat_connect.mp_app_secret", "") + viper.SetDefault("wechat_connect.mobile_app_id", "") + viper.SetDefault("wechat_connect.mobile_app_secret", "") + viper.SetDefault("wechat_connect.open_enabled", false) + viper.SetDefault("wechat_connect.mp_enabled", false) + viper.SetDefault("wechat_connect.mobile_enabled", false) + viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode) + viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes) + viper.SetDefault("wechat_connect.redirect_url", "") + viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect) + // Generic OIDC OAuth 登录 viper.SetDefault("oidc_connect.enabled", false) viper.SetDefault("oidc_connect.provider_name", "OIDC") @@ -1217,7 +1485,7 @@ func setDefaults() { viper.SetDefault("oidc_connect.redirect_url", "") viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") - viper.SetDefault("oidc_connect.use_pkce", false) + viper.SetDefault("oidc_connect.use_pkce", true) viper.SetDefault("oidc_connect.validate_id_token", true) viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") viper.SetDefault("oidc_connect.clock_skew_seconds", 120) @@ -1407,7 +1675,7 @@ func setDefaults() { viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) - 
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) @@ -1629,9 +1897,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") @@ -1662,6 +1927,45 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } + if c.WeChat.Enabled { + weChat := c.WeChat + normalizeWeChatConnectConfig(&weChat) + + if weChat.OpenEnabled { + if strings.TrimSpace(weChat.OpenAppID) == "" { + return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true") + } + if strings.TrimSpace(weChat.OpenAppSecret) == "" { + return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true") + } + } + if weChat.MPEnabled { + if strings.TrimSpace(weChat.MPAppID) == "" { + return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true") + } + if strings.TrimSpace(weChat.MPAppSecret) == "" { + return fmt.Errorf("wechat_connect.mp_app_secret is required 
when wechat_connect.mp_enabled=true") + } + } + if weChat.MobileEnabled { + if strings.TrimSpace(weChat.MobileAppID) == "" { + return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true") + } + if strings.TrimSpace(weChat.MobileAppSecret) == "" { + return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true") + } + } + if v := strings.TrimSpace(weChat.RedirectURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err) + } + warnIfInsecureURL("wechat_connect.redirect_url", v) + } + if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil { + return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err) + } + warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL) + } if c.OIDC.Enabled { if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") @@ -1685,9 +1989,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.OIDC.ClientSecret) == "" { return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index cf58316c316a7297ff7b84e1e8a12f3e066fe8e3..6ba86aa1be3a9b6a316b12f900efe5d9cae15ad1 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -225,6 +225,52 @@ func 
TestLoadSchedulingConfigFromEnv(t *testing.T) { } } +func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") + t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback") + + cfg, err := Load() + require.NoError(t, err) + require.True(t, cfg.WeChat.Enabled) + require.True(t, cfg.WeChat.OpenEnabled) + require.True(t, cfg.WeChat.MPEnabled) + require.False(t, cfg.WeChat.MobileEnabled) + require.Equal(t, "open", cfg.WeChat.Mode) + require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID) + require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret) + require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID) + require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret) + require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL) +} + +func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + require.NoError(t, err) + require.True(t, cfg.OIDC.UsePKCE) + require.True(t, cfg.OIDC.ValidateIDToken) + require.False(t, cfg.OIDC.UsePKCEExplicit) + require.False(t, cfg.OIDC.ValidateIDTokenExplicit) +} + +func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("OIDC_CONNECT_USE_PKCE", "false") + t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false") + + cfg, err := Load() + require.NoError(t, err) + require.False(t, cfg.OIDC.UsePKCE) + require.False(t, cfg.OIDC.ValidateIDToken) + require.True(t, cfg.OIDC.UsePKCEExplicit) + require.True(t, cfg.OIDC.ValidateIDTokenExplicit) +} + func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { resetViperWithJWTSecret(t) @@ -334,7 +380,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { cfg.LinuxDo.ClientSecret = "test-secret" 
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" cfg.LinuxDo.TokenAuthMethod = "client_secret_post" - cfg.LinuxDo.UsePKCE = false + cfg.LinuxDo.UsePKCE = true cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" err = cfg.Validate() @@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } } -func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { +func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() @@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { cfg.LinuxDo.UsePKCE = false err = cfg.Validate() - if err == nil { - t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") - } - if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { - t.Fatalf("Validate() expected use_pkce error, got: %v", err) + if err != nil { + t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err) } } @@ -389,6 +432,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) { cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "profile email" + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err == nil { @@ -418,6 +462,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "openid email profile" cfg.OIDC.ValidateIDToken = true + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err != nil { @@ -425,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T } } +func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.OIDC.Enabled = true 
+ cfg.OIDC.ClientID = "oidc-client" + cfg.OIDC.ClientSecret = "oidc-secret" + cfg.OIDC.IssuerURL = "https://issuer.example.com" + cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth" + cfg.OIDC.TokenURL = "https://issuer.example.com/token" + cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo" + cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" + cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" + cfg.OIDC.Scopes = "openid email profile" + cfg.OIDC.UsePKCE = false + cfg.OIDC.ValidateIDToken = false + cfg.OIDC.JWKSURL = "" + cfg.OIDC.AllowedSigningAlgs = "" + + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err) + } +} + func TestLoadDefaultDashboardCacheConfig(t *testing.T) { resetViperWithJWTSecret(t) @@ -840,6 +914,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = true if err := cfg.Validate(); err != nil { t.Fatalf("Validate() unexpected error: %v", err) @@ -990,6 +1065,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo client id required", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "" }, wantErr: "linuxdo_connect.client_id", @@ -998,6 +1074,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo token auth method", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "client" c.LinuxDo.ClientSecret = "secret" c.LinuxDo.AuthorizeURL = "https://example.com/authorize" diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 
cba3ae21494bcaa5cb500a8c36911787486e005c..ddeaab0218a203abb7f00216289636773ac19bf1 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) @@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + bindBody := map[string]any{ + "provider_type": "wechat", + "provider_key": "wechat-main", + "provider_subject": "union-123", + "metadata": map[string]any{"source": "admin-repair"}, + "channel": map[string]any{ + "channel": "open", + "channel_app_id": "wx-open", + "channel_subject": "openid-123", + }, + } + body, _ := json.Marshal(bindBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} - body, _ := json.Marshal(createBody) + body, _ = json.Marshal(createBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) } +func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) { + router, adminSvc := setupAdminRouter() + + body, 
err := json.Marshal(map[string]any{ + "provider_type": "oidc", + "provider_key": "https://issuer.example", + "provider_subject": "subject-123", + "issuer": "https://issuer.example", + "metadata": map[string]any{"report_id": 12}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor) + require.NotNil(t, adminSvc.boundAuthIdentity) + require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType) + require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey) + require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject) + require.Nil(t, adminSvc.boundAuthIdentity.Channel) + require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"]) +} + func TestGroupHandlerEndpoints(t *testing.T) { router, _ := setupAdminRouter() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 6d1ef1b6b82426d8d2679808eb16073410cec2e2..2fe29fa36e94b689a6ed57eeee5071e974ee7845 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -17,6 +17,8 @@ type stubAdminService struct { proxies []service.Proxy proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode + boundAuthIdentity *service.AdminBindAuthIdentityInput + boundAuthIdentityFor int64 createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -42,6 +44,14 @@ type stubAdminService struct { sortOrder string calls int } + lastListUsers struct { + page int + pageSize int + filters service.UserListFilters + sortBy string + sortOrder string + calls int 
+ } lastListProxies struct { protocol string status string @@ -127,6 +137,12 @@ func newStubAdminService() *stubAdminService { } func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) { + s.lastListUsers.page = page + s.lastListUsers.pageSize = pageSize + s.lastListUsers.filters = filters + s.lastListUsers.sortBy = sortBy + s.lastListUsers.sortOrder = sortOrder + s.lastListUsers.calls++ return s.users, int64(len(s.users)), nil } @@ -167,6 +183,63 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) { + user, err := s.GetUser(ctx, userID) + if err != nil { + return nil, err + } + return &service.UserRPMStatus{ + UserRPMUsed: 0, + UserRPMLimit: user.RPMLimit, + }, nil +} + +func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { + s.boundAuthIdentityFor = userID + copied := input + if input.Metadata != nil { + copied.Metadata = map[string]any{} + for key, value := range input.Metadata { + copied.Metadata[key] = value + } + } + if input.Channel != nil { + channel := *input.Channel + if input.Channel.Metadata != nil { + channel.Metadata = map[string]any{} + for key, value := range input.Channel.Metadata { + channel.Metadata[key] = value + } + } + copied.Channel = &channel + } + s.boundAuthIdentity = &copied + + now := time.Now().UTC() + result := &service.AdminBoundAuthIdentity{ + UserID: userID, + ProviderType: input.ProviderType, + ProviderKey: input.ProviderKey, + ProviderSubject: input.ProviderSubject, + VerifiedAt: &now, + Issuer: input.Issuer, + Metadata: input.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + if input.Channel != nil { + result.Channel = 
&service.AdminBoundAuthIdentityChannel{ + Channel: input.Channel.Channel, + ChannelAppID: input.Channel.ChannelAppID, + ChannelSubject: input.Channel.ChannelSubject, + Metadata: input.Channel.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + } + return result, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } @@ -214,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } +func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + return nil +} + +func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error { + return nil +} + func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { s.lastListAccounts.platform = platform s.lastListAccounts.accountType = accountType diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 9151d01872feac8da456b69ce045af7e187005dd..950e6e727456be07e44e09d0ed9396c94e9a1bb1 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse { UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } resp.BillingModelSource = ch.BillingModelSource - if resp.BillingModelSource == "" { - resp.BillingModelSource = service.BillingModelSourceChannelMapped - } if resp.GroupIDs == nil { resp.GroupIDs = []int64{} } diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 
f218cce45361cfbfec2fddbfc6c9e9d59ee2d8a5..12cd4bdda1fe7777639a3d7f8386428661c7cda9 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { ch := &service.Channel{ ID: 1, Name: "ch", - BillingModelSource: "", + BillingModelSource: service.BillingModelSourceChannelMapped, CreatedAt: now, UpdatedAt: now, GroupIDs: nil, @@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { }, } + // handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底 + // 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理), + // 因此这里构造 fixture 时直接传入归一化后的值。 resp := channelToResponse(ch) require.Equal(t, "channel_mapped", resp.BillingModelSource) require.NotNil(t, resp.GroupIDs) @@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { require.Equal(t, "token", resp.ModelPricing[0].BillingMode) } +func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) { + // handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。 + ch := &service.Channel{ + ID: 1, + Name: "ch", + BillingModelSource: "", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + resp := channelToResponse(ch) + require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责") +} + func TestChannelToResponse_NilModels(t *testing.T) { now := time.Now() ch := &service.Channel{ diff --git a/backend/internal/handler/admin/channel_monitor_handler.go b/backend/internal/handler/admin/channel_monitor_handler.go new file mode 100644 index 0000000000000000000000000000000000000000..e92c81fea0155d810b512fbccd836c96f31a925d --- /dev/null +++ b/backend/internal/handler/admin/channel_monitor_handler.go @@ -0,0 +1,427 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + // monitorMaxPageSize 列表分页上限。 + monitorMaxPageSize = 100 + // monitorAPIKeyMaskPrefix 脱敏时保留的明文前缀长度。 + monitorAPIKeyMaskPrefix = 4 + // monitorAPIKeyMaskSuffix 脱敏后追加的占位字符串。 + monitorAPIKeyMaskSuffix = "***" +) + +// ChannelMonitorHandler 渠道监控管理后台 handler。 +type ChannelMonitorHandler struct { + monitorService *service.ChannelMonitorService +} + +// NewChannelMonitorHandler 创建 handler。 +func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorHandler { + return &ChannelMonitorHandler{monitorService: monitorService} +} + +// --- Request / Response --- + +type channelMonitorCreateRequest struct { + Name string `json:"name" binding:"required,max=100"` + Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"` + Endpoint string `json:"endpoint" binding:"required,max=500"` + APIKey string `json:"api_key" binding:"required,max=2000"` + PrimaryModel string `json:"primary_model" binding:"required,max=200"` + ExtraModels []string `json:"extra_models"` + GroupName string `json:"group_name" binding:"max=100"` + Enabled *bool `json:"enabled"` + IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"` + TemplateID *int64 `json:"template_id"` + ExtraHeaders map[string]string `json:"extra_headers"` + BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"` + BodyOverride map[string]any `json:"body_override"` +} + +type channelMonitorUpdateRequest struct { + Name *string `json:"name" binding:"omitempty,max=100"` + Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"` + Endpoint *string `json:"endpoint" binding:"omitempty,max=500"` + APIKey *string `json:"api_key" 
binding:"omitempty,max=2000"` + PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"` + ExtraModels *[]string `json:"extra_models"` + GroupName *string `json:"group_name" binding:"omitempty,max=100"` + Enabled *bool `json:"enabled"` + IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"` + TemplateID *int64 `json:"template_id"` + ClearTemplate bool `json:"clear_template"` // true 时把 template_id 置空,忽略 TemplateID + ExtraHeaders *map[string]string `json:"extra_headers"` + BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"` + BodyOverride *map[string]any `json:"body_override"` +} + +type channelMonitorResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Endpoint string `json:"endpoint"` + APIKeyMasked string `json:"api_key_masked"` + APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"` + PrimaryModel string `json:"primary_model"` + ExtraModels []string `json:"extra_models"` + GroupName string `json:"group_name"` + Enabled bool `json:"enabled"` + IntervalSeconds int `json:"interval_seconds"` + LastCheckedAt *string `json:"last_checked_at"` + CreatedBy int64 `json:"created_by"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + PrimaryStatus string `json:"primary_status"` + PrimaryLatencyMs *int `json:"primary_latency_ms"` + Availability7d float64 `json:"availability_7d"` + ExtraModelsStatus []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"` + // 请求自定义快照:前端编辑 / 展示「高级设置」用 + TemplateID *int64 `json:"template_id"` + ExtraHeaders map[string]string `json:"extra_headers"` + BodyOverrideMode string `json:"body_override_mode"` + BodyOverride map[string]any `json:"body_override"` +} + +type channelMonitorCheckResultResponse struct { + Model string `json:"model"` + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + Message 
string `json:"message"` + CheckedAt string `json:"checked_at"` +} + +type channelMonitorHistoryItemResponse struct { + ID int64 `json:"id"` + Model string `json:"model"` + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + Message string `json:"message"` + CheckedAt string `json:"checked_at"` +} + +// maskAPIKey 对 API Key 明文做脱敏:前 4 字符 + "***",长度 ≤ 4 时只显示 "***"。 +func maskAPIKey(plain string) string { + if len(plain) <= monitorAPIKeyMaskPrefix { + return monitorAPIKeyMaskSuffix + } + return plain[:monitorAPIKeyMaskPrefix] + monitorAPIKeyMaskSuffix +} + +func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse { + if m == nil { + return nil + } + extras := m.ExtraModels + if extras == nil { + extras = []string{} + } + headers := m.ExtraHeaders + if headers == nil { + headers = map[string]string{} + } + resp := &channelMonitorResponse{ + ID: m.ID, + Name: m.Name, + Provider: m.Provider, + Endpoint: m.Endpoint, + APIKeyMasked: maskAPIKey(m.APIKey), + APIKeyDecryptFailed: m.APIKeyDecryptFailed, + PrimaryModel: m.PrimaryModel, + ExtraModels: extras, + GroupName: m.GroupName, + Enabled: m.Enabled, + IntervalSeconds: m.IntervalSeconds, + CreatedBy: m.CreatedBy, + CreatedAt: m.CreatedAt.UTC().Format(time.RFC3339), + UpdatedAt: m.UpdatedAt.UTC().Format(time.RFC3339), + TemplateID: m.TemplateID, + ExtraHeaders: headers, + BodyOverrideMode: m.BodyOverrideMode, + BodyOverride: m.BodyOverride, + // PrimaryStatus / PrimaryLatencyMs / Availability7d 由 List handler 在批量聚合后填充。 + } + if m.LastCheckedAt != nil { + s := m.LastCheckedAt.UTC().Format(time.RFC3339) + resp.LastCheckedAt = &s + } + return resp +} + +func checkResultToResponse(r *service.CheckResult) channelMonitorCheckResultResponse { + return channelMonitorCheckResultResponse{ + Model: r.Model, + Status: r.Status, + LatencyMs: r.LatencyMs, + PingLatencyMs: r.PingLatencyMs, + Message: r.Message, + CheckedAt: 
r.CheckedAt.UTC().Format(time.RFC3339), + } +} + +func historyEntryToResponse(e *service.ChannelMonitorHistoryEntry) channelMonitorHistoryItemResponse { + return channelMonitorHistoryItemResponse{ + ID: e.ID, + Model: e.Model, + Status: e.Status, + LatencyMs: e.LatencyMs, + PingLatencyMs: e.PingLatencyMs, + Message: e.Message, + CheckedAt: e.CheckedAt.UTC().Format(time.RFC3339), + } +} + +// ParseChannelMonitorID 提取并校验路径参数 :id(admin 与 user handler 共享)。 +// 校验失败时已写入 4xx 响应,调用方只需 return。 +func ParseChannelMonitorID(c *gin.Context) (int64, bool) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || id <= 0 { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_MONITOR_ID", "invalid monitor id")) + return 0, false + } + return id, true +} + +// parseListEnabled 解析 enabled query 参数:true/false 转为 *bool,空或非法则返回 nil。 +func parseListEnabled(raw string) *bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "true", "1", "yes": + v := true + return &v + case "false", "0", "no": + v := false + return &v + default: + return nil + } +} + +// --- Handlers --- + +// List GET /api/v1/admin/channel-monitors +func (h *ChannelMonitorHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + if pageSize > monitorMaxPageSize { + pageSize = monitorMaxPageSize + } + + params := service.ChannelMonitorListParams{ + Page: page, + PageSize: pageSize, + Provider: strings.TrimSpace(c.Query("provider")), + Enabled: parseListEnabled(c.Query("enabled")), + Search: strings.TrimSpace(c.Query("search")), + } + + items, total, err := h.monitorService.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + + summaries := h.batchSummaryFor(c, items) + out := make([]*channelMonitorResponse, 0, len(items)) + for _, m := range items { + out = append(out, buildListItemResponse(m, summaries[m.ID])) + } + response.Paginated(c, out, total, page, pageSize) +} + +// batchSummaryFor 批量聚合 latest + 7d 可用率,避免每行 2 次 
SQL(消除 N+1)。 +func (h *ChannelMonitorHandler) batchSummaryFor(c *gin.Context, items []*service.ChannelMonitor) map[int64]service.MonitorStatusSummary { + ids := make([]int64, 0, len(items)) + primaryByID := make(map[int64]string, len(items)) + extrasByID := make(map[int64][]string, len(items)) + for _, m := range items { + ids = append(ids, m.ID) + primaryByID[m.ID] = m.PrimaryModel + extrasByID[m.ID] = m.ExtraModels + } + return h.monitorService.BatchMonitorStatusSummary(c.Request.Context(), ids, primaryByID, extrasByID) +} + +// buildListItemResponse 把 monitor + summary 装成 admin list 的响应行。 +func buildListItemResponse(m *service.ChannelMonitor, summary service.MonitorStatusSummary) *channelMonitorResponse { + resp := channelMonitorToResponse(m) + resp.PrimaryStatus = summary.PrimaryStatus + resp.PrimaryLatencyMs = summary.PrimaryLatencyMs + resp.Availability7d = summary.Availability7d + resp.ExtraModelsStatus = make([]dto.ChannelMonitorExtraModelStatus, 0, len(summary.ExtraModels)) + for _, e := range summary.ExtraModels { + resp.ExtraModelsStatus = append(resp.ExtraModelsStatus, dto.ChannelMonitorExtraModelStatus{ + Model: e.Model, + Status: e.Status, + LatencyMs: e.LatencyMs, + }) + } + return resp +} + +// Get GET /api/v1/admin/channel-monitors/:id +func (h *ChannelMonitorHandler) Get(c *gin.Context) { + id, ok := ParseChannelMonitorID(c) + if !ok { + return + } + m, err := h.monitorService.Get(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, channelMonitorToResponse(m)) +} + +// Create POST /api/v1/admin/channel-monitors +func (h *ChannelMonitorHandler) Create(c *gin.Context) { + var req channelMonitorCreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + subject, _ := middleware2.GetAuthSubjectFromContext(c) + + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + + m, err := 
h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{ + Name: req.Name, + Provider: req.Provider, + Endpoint: req.Endpoint, + APIKey: req.APIKey, + PrimaryModel: req.PrimaryModel, + ExtraModels: req.ExtraModels, + GroupName: req.GroupName, + Enabled: enabled, + IntervalSeconds: req.IntervalSeconds, + CreatedBy: subject.UserID, + TemplateID: req.TemplateID, + ExtraHeaders: req.ExtraHeaders, + BodyOverrideMode: req.BodyOverrideMode, + BodyOverride: req.BodyOverride, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Created(c, channelMonitorToResponse(m)) +} + +// Update PUT /api/v1/admin/channel-monitors/:id +func (h *ChannelMonitorHandler) Update(c *gin.Context) { + id, ok := ParseChannelMonitorID(c) + if !ok { + return + } + var req channelMonitorUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{ + Name: req.Name, + Provider: req.Provider, + Endpoint: req.Endpoint, + APIKey: req.APIKey, + PrimaryModel: req.PrimaryModel, + ExtraModels: req.ExtraModels, + GroupName: req.GroupName, + Enabled: req.Enabled, + IntervalSeconds: req.IntervalSeconds, + TemplateID: req.TemplateID, + ClearTemplate: req.ClearTemplate, + ExtraHeaders: req.ExtraHeaders, + BodyOverrideMode: req.BodyOverrideMode, + BodyOverride: req.BodyOverride, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, channelMonitorToResponse(m)) +} + +// Delete DELETE /api/v1/admin/channel-monitors/:id +func (h *ChannelMonitorHandler) Delete(c *gin.Context) { + id, ok := ParseChannelMonitorID(c) + if !ok { + return + } + if err := h.monitorService.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) +} + +// Run POST /api/v1/admin/channel-monitors/:id/run +func (h 
*ChannelMonitorHandler) Run(c *gin.Context) { + id, ok := ParseChannelMonitorID(c) + if !ok { + return + } + results, err := h.monitorService.RunCheck(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]channelMonitorCheckResultResponse, 0, len(results)) + for _, r := range results { + out = append(out, checkResultToResponse(r)) + } + response.Success(c, gin.H{"results": out}) +} + +// History GET /api/v1/admin/channel-monitors/:id/history +func (h *ChannelMonitorHandler) History(c *gin.Context) { + id, ok := ParseChannelMonitorID(c) + if !ok { + return + } + limit := parseHistoryLimit(c.Query("limit")) + model := strings.TrimSpace(c.Query("model")) + + entries, err := h.monitorService.ListHistory(c.Request.Context(), id, model, limit) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]channelMonitorHistoryItemResponse, 0, len(entries)) + for _, e := range entries { + out = append(out, historyEntryToResponse(e)) + } + response.Success(c, gin.H{"items": out}) +} + +// parseHistoryLimit 解析 history 接口的 limit query。 +// 使用 service 包的统一上下限常量,避免在 handler 重复定义同名魔法值。 +func parseHistoryLimit(raw string) int { + if strings.TrimSpace(raw) == "" { + return service.MonitorHistoryDefaultLimit + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + return service.MonitorHistoryDefaultLimit + } + if v > service.MonitorHistoryMaxLimit { + return service.MonitorHistoryMaxLimit + } + return v +} diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go new file mode 100644 index 0000000000000000000000000000000000000000..bebe092907e614c799e02c818f0f7def7b933555 --- /dev/null +++ b/backend/internal/handler/admin/channel_monitor_template_handler.go @@ -0,0 +1,234 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。 +type ChannelMonitorRequestTemplateHandler struct { + templateService *service.ChannelMonitorRequestTemplateService +} + +// NewChannelMonitorRequestTemplateHandler 创建 handler。 +func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler { + return &ChannelMonitorRequestTemplateHandler{templateService: templateService} +} + +// --- DTO --- + +type channelMonitorTemplateCreateRequest struct { + Name string `json:"name" binding:"required,max=100"` + Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"` + Description string `json:"description" binding:"max=500"` + ExtraHeaders map[string]string `json:"extra_headers"` + BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"` + BodyOverride map[string]any `json:"body_override"` +} + +type channelMonitorTemplateUpdateRequest struct { + Name *string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description" binding:"omitempty,max=500"` + ExtraHeaders *map[string]string `json:"extra_headers"` + BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"` + BodyOverride *map[string]any `json:"body_override"` +} + +type channelMonitorTemplateResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Description string `json:"description"` + ExtraHeaders map[string]string `json:"extra_headers"` + BodyOverrideMode string `json:"body_override_mode"` + BodyOverride map[string]any `json:"body_override"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + AssociatedMonitors int64 `json:"associated_monitors"` +} + +func (h 
*ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse { + if t == nil { + return nil + } + headers := t.ExtraHeaders + if headers == nil { + headers = map[string]string{} + } + count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID) + return &channelMonitorTemplateResponse{ + ID: t.ID, + Name: t.Name, + Provider: t.Provider, + Description: t.Description, + ExtraHeaders: headers, + BodyOverrideMode: t.BodyOverrideMode, + BodyOverride: t.BodyOverride, + CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339), + UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339), + AssociatedMonitors: count, + } +} + +// parseTemplateID 提取并校验 :id。 +func parseTemplateID(c *gin.Context) (int64, bool) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil || id <= 0 { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id")) + return 0, false + } + return id, true +} + +// --- Handlers --- + +// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic +func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) { + items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{ + Provider: strings.TrimSpace(c.Query("provider")), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]*channelMonitorTemplateResponse, 0, len(items)) + for _, t := range items { + out = append(out, h.toResponse(c, t)) + } + response.Success(c, gin.H{"items": out}) +} + +// Get GET /api/v1/admin/channel-monitor-templates/:id +func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) { + id, ok := parseTemplateID(c) + if !ok { + return + } + t, err := h.templateService.Get(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, h.toResponse(c, t)) +} + +// Create POST /api/v1/admin/channel-monitor-templates +func 
(h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) { + var req channelMonitorTemplateCreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{ + Name: req.Name, + Provider: req.Provider, + Description: req.Description, + ExtraHeaders: req.ExtraHeaders, + BodyOverrideMode: req.BodyOverrideMode, + BodyOverride: req.BodyOverride, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Created(c, h.toResponse(c, t)) +} + +// Update PUT /api/v1/admin/channel-monitor-templates/:id +func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) { + id, ok := parseTemplateID(c) + if !ok { + return + } + var req channelMonitorTemplateUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{ + Name: req.Name, + Description: req.Description, + ExtraHeaders: req.ExtraHeaders, + BodyOverrideMode: req.BodyOverrideMode, + BodyOverride: req.BodyOverride, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, h.toResponse(c, t)) +} + +// Delete DELETE /api/v1/admin/channel-monitor-templates/:id +func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) { + id, ok := parseTemplateID(c) + if !ok { + return + } + if err := h.templateService.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) +} + +type channelMonitorTemplateApplyRequest struct { + // MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。 + // 仅当对应监控当前 template_id == :id 时才会真的被覆盖。 + MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"` +} + +// Apply POST 
/api/v1/admin/channel-monitor-templates/:id/apply +// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。 +func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) { + id, ok := parseTemplateID(c) + if !ok { + return + } + var req channelMonitorTemplateApplyRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"affected": affected}) +} + +type associatedMonitorBriefResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` +} + +// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors +// 列出关联监控(picker 弹窗用)。 +func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) { + id, ok := parseTemplateID(c) + if !ok { + return + } + items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]associatedMonitorBriefResponse, 0, len(items)) + for _, m := range items { + out = append(out, associatedMonitorBriefResponse{ + ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled, + }) + } + response.Success(c, gin.H{"items": out}) +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index cb2bd2018e5dff0cbdb23e06d308a8a5bb16e98d..65e5ec7802b429ebe807bebd1ead622042e8cfc5 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -110,6 +110,8 @@ type CreateGroupRequest struct { RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig 
`json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制) + RPMLimit int `json:"rpm_limit"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -145,6 +147,8 @@ type UpdateGroupRequest struct { RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动 + RPMLimit *int `json:"rpm_limit"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) { response.Success(c, gin.H{"message": "Rate multipliers updated successfully"}) } +// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request +type BatchSetGroupRPMOverridesRequest struct { + Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"` +} + +// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group +// PUT /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + 
return + } + + var req BatchSetGroupRPMOverridesRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "RPM overrides updated successfully"}) +} + +// ClearGroupRPMOverrides handles clearing all rpm_override for a group +// DELETE /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "RPM overrides cleared successfully"}) +} + // UpdateSortOrderRequest represents the request to update group sort orders type UpdateSortOrderRequest struct { Updates []struct { diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go index b0ed6aed8f7fca88c0ed58d7b9744958e70a6766..84359cd93bdecfae7d562a57a29655109a2ddd2a 100644 --- a/backend/internal/handler/admin/payment_handler.go +++ b/backend/internal/handler/admin/payment_handler.go @@ -3,6 +3,7 @@ package admin import ( "strconv" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Paginated(c, orders, int64(total), page, pageSize) + response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize) } // GetOrderDetail returns detailed information about a single order. 
@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) { return } auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID) - response.Success(c, gin.H{"order": order, "auditLogs": auditLogs}) + response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs}) } // CancelOrder cancels a pending order (admin). @@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) { response.Success(c, gin.H{"message": "fulfillment retried"}) } +func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder { + if len(orders) == 0 { + return orders + } + out := make([]*dbent.PaymentOrder, 0, len(orders)) + for _, order := range orders { + out = append(out, sanitizeAdminPaymentOrderForResponse(order)) + } + return out +} + +func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder { + if order == nil { + return nil + } + cloned := *order + cloned.ProviderSnapshot = nil + return &cloned +} + // AdminProcessRefundRequest is the request body for admin refund processing. 
type AdminProcessRefundRequest struct { Amount float64 `json:"amount"` diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index bec0f126137a2b05ca7acde898e4c1b6d21943ad..4277f0f12bc21dfa61a323c964aad48effe37e37 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -43,6 +43,15 @@ func scopesContainOpenID(scopes string) bool { return false } +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -73,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) @@ -93,114 +107,142 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { paymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - FrontendURL: settings.FrontendURL, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPasswordConfigured: settings.SMTPPasswordConfigured, - 
SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, - LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, - LinuxDoConnectClientID: settings.LinuxDoConnectClientID, - LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, - LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, - OIDCConnectEnabled: settings.OIDCConnectEnabled, - OIDCConnectProviderName: settings.OIDCConnectProviderName, - OIDCConnectClientID: settings.OIDCConnectClientID, - OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured, - OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL, - OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL, - OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL, - OIDCConnectTokenURL: settings.OIDCConnectTokenURL, - OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL, - OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL, - OIDCConnectScopes: settings.OIDCConnectScopes, - OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL, - OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL, - OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken, - OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs, - OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds, - OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified, - OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, - OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, - OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: 
settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - TableDefaultPageSize: settings.TableDefaultPageSize, - TablePageSizeOptions: settings.TablePageSizeOptions, - CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), - CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, - OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, - OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, - OpsQueryModeDefault: settings.OpsQueryModeDefault, - OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: settings.MinClaudeCodeVersion, - MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion, - AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, - BackendModeEnabled: settings.BackendModeEnabled, - EnableFingerprintUnification: settings.EnableFingerprintUnification, - EnableMetadataPassthrough: settings.EnableMetadataPassthrough, - EnableCCHSigning: settings.EnableCCHSigning, - WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, - BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, - 
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, - AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, - AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), - PaymentEnabled: paymentCfg.Enabled, - PaymentMinAmount: paymentCfg.MinAmount, - PaymentMaxAmount: paymentCfg.MaxAmount, - PaymentDailyLimit: paymentCfg.DailyLimit, - PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin, - PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, - PaymentEnabledTypes: paymentCfg.EnabledTypes, - PaymentBalanceDisabled: paymentCfg.BalanceDisabled, - PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier, - PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate, - PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, - PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, - PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, - PaymentHelpImageURL: paymentCfg.HelpImageURL, - PaymentHelpText: paymentCfg.HelpText, - PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled, - PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax, - PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, - PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, - PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - }) + payload := dto.SystemSettings{ + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + FrontendURL: settings.FrontendURL, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: 
settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: settings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: settings.WeChatConnectEnabled, + WeChatConnectAppID: settings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured, + WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured, + WeChatConnectMPAppID: settings.WeChatConnectMPAppID, + WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured, + WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured, + WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled, + WeChatConnectMode: settings.WeChatConnectMode, + WeChatConnectScopes: settings.WeChatConnectScopes, + WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL, + OIDCConnectEnabled: settings.OIDCConnectEnabled, + OIDCConnectProviderName: settings.OIDCConnectProviderName, + OIDCConnectClientID: settings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL, + 
OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: settings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL, + OIDCConnectScopes: settings.OIDCConnectScopes, + OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + TableDefaultPageSize: settings.TableDefaultPageSize, + TablePageSizeOptions: settings.TablePageSizeOptions, + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + DefaultUserRPMLimit: settings.DefaultUserRPMLimit, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, 
+ FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, + OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, + OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, + OpsQueryModeDefault: settings.OpsQueryModeDefault, + OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: settings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion, + AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, + BackendModeEnabled: settings.BackendModeEnabled, + EnableFingerprintUnification: settings.EnableFingerprintUnification, + EnableMetadataPassthrough: settings.EnableMetadataPassthrough, + EnableCCHSigning: settings.EnableCCHSigning, + WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), + PaymentEnabled: paymentCfg.Enabled, + PaymentMinAmount: paymentCfg.MinAmount, + PaymentMaxAmount: paymentCfg.MaxAmount, + PaymentDailyLimit: paymentCfg.DailyLimit, + PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin, + 
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, + PaymentEnabledTypes: paymentCfg.EnabledTypes, + PaymentBalanceDisabled: paymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate, + PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, + PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, + PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, + PaymentHelpImageURL: paymentCfg.HelpImageURL, + PaymentHelpText: paymentCfg.HelpText, + PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled, + PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax, + PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, + PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, + PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, + + ChannelMonitorEnabled: settings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, + + AvailableChannelsEnabled: settings.AvailableChannelsEnabled, + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } // UpdateSettingsRequest 更新设置请求 @@ -235,6 +277,24 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecret string `json:"wechat_connect_app_secret"` + WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` + WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"` + WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` + WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"` + WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` + WeChatConnectMobileAppSecret string 
`json:"wechat_connect_mobile_app_secret"` + WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` + WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` + WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` @@ -250,8 +310,8 @@ type UpdateSettingsRequest struct { OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"` OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"` OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"` - OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"` - OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"` + OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"` + OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"` OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"` OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"` OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"` @@ -276,9 +336,31 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting 
`json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool 
`json:"auth_source_default_wechat_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -311,6 +393,15 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"` + // Balance low notification BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` @@ -341,6 +432,13 @@ type UpdateSettingsRequest struct { PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"` PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"` PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"` + + // Channel Monitor feature switch + ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"` + + // Available Channels feature switch (user-facing) + AvailableChannelsEnabled *bool `json:"available_channels_enabled"` } // UpdateSettings 更新系统设置 @@ -357,6 +455,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // 验证参数 if 
req.DefaultConcurrency < 1 { @@ -381,6 +484,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions) + req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) + req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) + req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -459,7 +566,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + if req.WeChatConnectEnabled { + req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID) + req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret) + req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID) + req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret) + req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID) + req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret) + req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID) + req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret) + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode)) + req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes) + req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL) + req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID)) + 
req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL)) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL)) + if req.WeChatConnectMode == "" { + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode)) + } + if req.WeChatConnectScopes == "" { + req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes) + } + + if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled { + response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time") + return + } + if req.WeChatConnectMode != "" { + switch req.WeChatConnectMode { + case "open", "mp", "mobile": + default: + response.BadRequest(c, "WeChat mode must be open, mp, or mobile") + return + } + } + if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled { + switch req.WeChatConnectMode { + case "mp": + req.WeChatConnectMPEnabled = true + case "mobile": + req.WeChatConnectMobileEnabled = true + default: + req.WeChatConnectOpenEnabled = true + } + } + if req.WeChatConnectMode == "" { + if req.WeChatConnectMPEnabled { + req.WeChatConnectMode = "mp" + } else if req.WeChatConnectMobileEnabled { + req.WeChatConnectMode = "mobile" + } else { + req.WeChatConnectMode = "open" + } + } + + req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID)) + req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID)) + req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, 
previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID)) + + if req.WeChatConnectOpenAppSecret == "" { + req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectMPAppSecret == "" { + req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectMobileAppSecret == "" { + req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret)) + } + if req.WeChatConnectAppSecret == "" { + req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret)) + } + + if req.WeChatConnectOpenEnabled { + if req.WeChatConnectOpenAppID == "" { + response.BadRequest(c, "WeChat PC App ID is required when enabled") + return + } + if req.WeChatConnectOpenAppSecret == "" { + response.BadRequest(c, "WeChat PC App Secret is required when enabled") + return + } + } + if req.WeChatConnectMPEnabled { + if req.WeChatConnectMPAppID == "" { + response.BadRequest(c, "WeChat Official Account App ID is required when enabled") + return + } + if req.WeChatConnectMPAppSecret == "" { + response.BadRequest(c, "WeChat Official Account App Secret is required when enabled") + return + } + } + if req.WeChatConnectMobileEnabled { + if req.WeChatConnectMobileAppID == "" { + response.BadRequest(c, "WeChat Mobile App ID is required when enabled") + return + } + if req.WeChatConnectMobileAppSecret == "" { + response.BadRequest(c, "WeChat Mobile App Secret is required when enabled") + return + } + } + + if req.WeChatConnectScopes == "" { + if 
req.WeChatConnectMPEnabled { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp") + } else { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode) + } + } + if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled { + if req.WeChatConnectRedirectURL == "" { + response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { + response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") + return + } + if req.WeChatConnectFrontendRedirectURL == "" { + req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" + } + if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") + return + } + } + } + // Generic OIDC 参数验证 + oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID) @@ -478,10 +719,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath) req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath) req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath) - - if req.OIDCConnectProviderName == "" { - req.OIDCConnectProviderName = "OIDC" + req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC")) + req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID)) + req.OIDCConnectIssuerURL = 
strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL)) + req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL)) + req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL)) + req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL)) + req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL)) + req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL)) + req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile")) + req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL)) + req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback")) + req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post"))) + req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256")) + req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath)) + req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath)) + req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, 
previousSettings.OIDCConnectUserInfoUsernamePath)) + if req.OIDCConnectUsePKCE != nil { + oidcUsePKCE = *req.OIDCConnectUsePKCE } + if req.OIDCConnectValidateIDToken != nil { + oidcValidateIDToken = *req.OIDCConnectValidateIDToken + } + if req.OIDCConnectClockSkewSeconds == 0 { + req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds + if req.OIDCConnectClockSkewSeconds == 0 { + req.OIDCConnectClockSkewSeconds = 120 + } + } + if req.OIDCConnectClientID == "" { response.BadRequest(c, "OIDC Client ID is required when enabled") return @@ -544,19 +810,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") return } - if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") - return - } if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken { - if req.OIDCConnectAllowedSigningAlgs == "" { - response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") - return - } + if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return } if req.OIDCConnectJWKSURL != "" { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { @@ -805,6 +1065,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: req.WeChatConnectEnabled, + WeChatConnectAppID: req.WeChatConnectAppID, + WeChatConnectAppSecret: req.WeChatConnectAppSecret, + 
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret, + WeChatConnectMPAppID: req.WeChatConnectMPAppID, + WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret, + WeChatConnectMobileAppID: req.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret, + WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: req.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled, + WeChatConnectMode: req.WeChatConnectMode, + WeChatConnectScopes: req.WeChatConnectScopes, + WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: req.OIDCConnectEnabled, OIDCConnectProviderName: req.OIDCConnectProviderName, OIDCConnectClientID: req.OIDCConnectClientID, @@ -819,8 +1095,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OIDCConnectRedirectURL: req.OIDCConnectRedirectURL, OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL, OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: req.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken, + OIDCConnectUsePKCE: oidcUsePKCE, + OIDCConnectValidateIDToken: oidcValidateIDToken, OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs, OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds, OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified, @@ -843,6 +1119,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, FallbackModelAnthropic: req.FallbackModelAnthropic, @@ -897,6 +1174,36 @@ func (h *SettingHandler) UpdateSettings(c 
*gin.Context) { } return previousSettings.EnableCCHSigning }(), + PaymentVisibleMethodAlipaySource: func() string { + if req.PaymentVisibleMethodAlipaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) + } + return previousSettings.PaymentVisibleMethodAlipaySource + }(), + PaymentVisibleMethodWxpaySource: func() string { + if req.PaymentVisibleMethodWxpaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource) + } + return previousSettings.PaymentVisibleMethodWxpaySource + }(), + PaymentVisibleMethodAlipayEnabled: func() bool { + if req.PaymentVisibleMethodAlipayEnabled != nil { + return *req.PaymentVisibleMethodAlipayEnabled + } + return previousSettings.PaymentVisibleMethodAlipayEnabled + }(), + PaymentVisibleMethodWxpayEnabled: func() bool { + if req.PaymentVisibleMethodWxpayEnabled != nil { + return *req.PaymentVisibleMethodWxpayEnabled + } + return previousSettings.PaymentVisibleMethodWxpayEnabled + }(), + OpenAIAdvancedSchedulerEnabled: func() bool { + if req.OpenAIAdvancedSchedulerEnabled != nil { + return *req.OpenAIAdvancedSchedulerEnabled + } + return previousSettings.OpenAIAdvancedSchedulerEnabled + }(), BalanceLowNotifyEnabled: func() bool { if req.BalanceLowNotifyEnabled != nil { return *req.BalanceLowNotifyEnabled @@ -927,9 +1234,58 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.AccountQuotaNotifyEmails }(), + ChannelMonitorEnabled: func() bool { + if req.ChannelMonitorEnabled != nil { + return *req.ChannelMonitorEnabled + } + return previousSettings.ChannelMonitorEnabled + }(), + ChannelMonitorDefaultIntervalSeconds: func() int { + if req.ChannelMonitorDefaultIntervalSeconds != nil { + return *req.ChannelMonitorDefaultIntervalSeconds + } + return previousSettings.ChannelMonitorDefaultIntervalSeconds + }(), + AvailableChannelsEnabled: func() bool { + if req.AvailableChannelsEnabled != nil { + return *req.AvailableChannelsEnabled + } + return 
previousSettings.AvailableChannelsEnabled + }(), } - if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { + authSourceDefaults := &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + }, + LinuxDo: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + }, + OIDC: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, 
previousAuthSourceDefaults.OIDC.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + }, + WeChat: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + }, + ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), + } + if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil { response.ErrorFrom(c, err) return } @@ -969,7 +1325,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } - h.auditSettingsUpdate(c, previousSettings, settings, req) + h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req) // 重新获取设置返回 updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) @@ -977,6 +1333,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } updatedDefaultSubscriptions := 
make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) for _, sub := range updatedSettings.DefaultSubscriptions { updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ @@ -994,113 +1355,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { updatedPaymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: updatedSettings.PromoCodeEnabled, - PasswordResetEnabled: updatedSettings.PasswordResetEnabled, - FrontendURL: updatedSettings.FrontendURL, - InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, - TotpEnabled: updatedSettings.TotpEnabled, - TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, - LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, - LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, - LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, - LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, - OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, - OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, - OIDCConnectClientID: updatedSettings.OIDCConnectClientID, - OIDCConnectClientSecretConfigured: 
updatedSettings.OIDCConnectClientSecretConfigured, - OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL, - OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL, - OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL, - OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL, - OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL, - OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL, - OIDCConnectScopes: updatedSettings.OIDCConnectScopes, - OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL, - OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL, - OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE, - OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken, - OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs, - OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds, - OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified, - OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, - OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, - OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - HomeContent: updatedSettings.HomeContent, - HideCcsImportButton: updatedSettings.HideCcsImportButton, - PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, - TableDefaultPageSize: updatedSettings.TableDefaultPageSize, - TablePageSizeOptions: updatedSettings.TablePageSizeOptions, - CustomMenuItems: 
dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), - CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - DefaultSubscriptions: updatedDefaultSubscriptions, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, - OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, - OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, - OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, - OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, - MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion, - AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, - BackendModeEnabled: updatedSettings.BackendModeEnabled, - EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, - EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, - EnableCCHSigning: updatedSettings.EnableCCHSigning, - BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, - BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, - AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, - AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), - PaymentEnabled: updatedPaymentCfg.Enabled, - PaymentMinAmount: updatedPaymentCfg.MinAmount, - PaymentMaxAmount: 
updatedPaymentCfg.MaxAmount, - PaymentDailyLimit: updatedPaymentCfg.DailyLimit, - PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin, - PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, - PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, - PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, - PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier, - PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate, - PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, - PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, - PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, - PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL, - PaymentHelpText: updatedPaymentCfg.HelpText, - PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled, - PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax, - PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, - PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, - PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - }) + payload := dto.SystemSettings{ + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: updatedSettings.PromoCodeEnabled, + PasswordResetEnabled: updatedSettings.PasswordResetEnabled, + FrontendURL: updatedSettings.FrontendURL, + InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, + TotpEnabled: updatedSettings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: 
updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled, + WeChatConnectAppID: updatedSettings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured, + WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured, + WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID, + WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured, + WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured, + WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled, + WeChatConnectMode: updatedSettings.WeChatConnectMode, + WeChatConnectScopes: updatedSettings.WeChatConnectScopes, + WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL, + OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, + OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, + OIDCConnectClientID: updatedSettings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: 
updatedSettings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL, + OIDCConnectScopes: updatedSettings.OIDCConnectScopes, + OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + HomeContent: updatedSettings.HomeContent, + HideCcsImportButton: updatedSettings.HideCcsImportButton, + PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + TableDefaultPageSize: updatedSettings.TableDefaultPageSize, + TablePageSizeOptions: updatedSettings.TablePageSizeOptions, + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomEndpoints: 
dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, + DefaultSubscriptions: updatedDefaultSubscriptions, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, + OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, + OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, + OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, + MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion, + AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, + BackendModeEnabled: updatedSettings.BackendModeEnabled, + EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, + EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, + EnableCCHSigning: updatedSettings.EnableCCHSigning, + PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled, + BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: 
updatedSettings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), + PaymentEnabled: updatedPaymentCfg.Enabled, + PaymentMinAmount: updatedPaymentCfg.MinAmount, + PaymentMaxAmount: updatedPaymentCfg.MaxAmount, + PaymentDailyLimit: updatedPaymentCfg.DailyLimit, + PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin, + PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, + PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, + PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate, + PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, + PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, + PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, + PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL, + PaymentHelpText: updatedPaymentCfg.HelpText, + PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled, + PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax, + PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, + PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, + PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, + + ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds, + + AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, + } + response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } // hasPaymentFields returns true if any payment-related field was explicitly provided. 
@@ -1117,12 +1506,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool { req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil } -func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) { +func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) { if before == nil || after == nil { return } - changed := diffSettings(before, after, req) + changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req) if len(changed) == 0 { return } @@ -1137,7 +1526,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys ) } -func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string { +func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string { changed := make([]string, 0, 20) if before.RegistrationEnabled != after.RegistrationEnabled { changed = append(changed, "registration_enabled") @@ -1205,6 +1594,54 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.WeChatConnectEnabled != after.WeChatConnectEnabled { + changed = append(changed, "wechat_connect_enabled") + } + if before.WeChatConnectAppID != after.WeChatConnectAppID { + changed = append(changed, "wechat_connect_app_id") + } + if req.WeChatConnectAppSecret != "" { + changed = append(changed, "wechat_connect_app_secret") + } 
+ if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID { + changed = append(changed, "wechat_connect_open_app_id") + } + if req.WeChatConnectOpenAppSecret != "" { + changed = append(changed, "wechat_connect_open_app_secret") + } + if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID { + changed = append(changed, "wechat_connect_mp_app_id") + } + if req.WeChatConnectMPAppSecret != "" { + changed = append(changed, "wechat_connect_mp_app_secret") + } + if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID { + changed = append(changed, "wechat_connect_mobile_app_id") + } + if req.WeChatConnectMobileAppSecret != "" { + changed = append(changed, "wechat_connect_mobile_app_secret") + } + if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled { + changed = append(changed, "wechat_connect_open_enabled") + } + if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled { + changed = append(changed, "wechat_connect_mp_enabled") + } + if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled { + changed = append(changed, "wechat_connect_mobile_enabled") + } + if before.WeChatConnectMode != after.WeChatConnectMode { + changed = append(changed, "wechat_connect_mode") + } + if before.WeChatConnectScopes != after.WeChatConnectScopes { + changed = append(changed, "wechat_connect_scopes") + } + if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL { + changed = append(changed, "wechat_connect_redirect_url") + } + if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL { + changed = append(changed, "wechat_connect_frontend_redirect_url") + } if before.OIDCConnectEnabled != after.OIDCConnectEnabled { changed = append(changed, "oidc_connect_enabled") } @@ -1376,6 +1813,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if 
before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { + changed = append(changed, "payment_visible_method_alipay_source") + } + if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource { + changed = append(changed, "payment_visible_method_wxpay_source") + } + if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled { + changed = append(changed, "payment_visible_method_alipay_enabled") + } + if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled { + changed = append(changed, "payment_visible_method_wxpay_enabled") + } + if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled { + changed = append(changed, "openai_advanced_scheduler_enabled") + } // Balance & quota notification if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { changed = append(changed, "balance_low_notify_enabled") @@ -1392,6 +1844,59 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) { changed = append(changed, "account_quota_notify_emails") } + if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled { + changed = append(changed, "channel_monitor_enabled") + } + if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds { + changed = append(changed, "channel_monitor_default_interval_seconds") + } + if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled { + changed = append(changed, "available_channels_enabled") + } + changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) + return changed +} + +func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string { + if before == nil { + before = &service.AuthSourceDefaultSettings{} + } + if 
after == nil { + after = &service.AuthSourceDefaultSettings{} + } + + type providerDefaultGrantField struct { + name string + before service.ProviderDefaultGrantSettings + after service.ProviderDefaultGrantSettings + } + + fields := []providerDefaultGrantField{ + {name: "email", before: before.Email, after: after.Email}, + {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo}, + {name: "oidc", before: before.OIDC, after: after.OIDC}, + {name: "wechat", before: before.WeChat, after: after.WeChat}, + } + for _, field := range fields { + if field.before.Balance != field.after.Balance { + changed = append(changed, "auth_source_default_"+field.name+"_balance") + } + if field.before.Concurrency != field.after.Concurrency { + changed = append(changed, "auth_source_default_"+field.name+"_concurrency") + } + if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) { + changed = append(changed, "auth_source_default_"+field.name+"_subscriptions") + } + if field.before.GrantOnSignup != field.after.GrantOnSignup { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup") + } + if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind { + changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind") + } + } + if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup { + changed = append(changed, "force_email_on_third_party_signup") + } return changed } @@ -1412,6 +1917,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting { + if input == nil { + return nil + } + normalized := normalizeDefaultSubscriptions(*input) + return &normalized +} + +func float64ValueOrDefault(value *float64, fallback float64) float64 { + if value == nil { + return fallback + } + return *value +} + +func intValueOrDefault(value 
*int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func boolValueOrDefault(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting { + if input == nil { + return fallback + } + result := make([]service.DefaultSubscriptionSetting, 0, len(*input)) + for _, item := range *input { + result = append(result, service.DefaultSubscriptionSetting{ + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + }) + } + return result +} + +func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { + data := make(map[string]any) + raw, err := json.Marshal(settings) + if err == nil { + _ = json.Unmarshal(raw, &data) + } + if authSourceDefaults == nil { + authSourceDefaults = &service.AuthSourceDefaultSettings{} + } + + data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance + data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency + data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions + data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup + data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind + data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance + data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency + data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions + data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup + data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_oidc_balance"] = 
authSourceDefaults.OIDC.Balance + data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency + data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions + data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup + data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind + data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance + data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency + data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions + data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup + data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup + + return data +} + func equalStringSlice(a, b []string) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a33a93aad7d9b220e6b154b489b173693d299df --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -0,0 +1,503 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerRepoStub struct { + values map[string]string + lastUpdates map[string]string +} + +func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + 
panic("unexpected Get call") +} + +func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.lastUpdates = make(map[string]string, len(settings)) + for key, value := range settings { + s.lastUpdates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type failingAuthSourceSettingsRepoStub struct { + values map[string]string + err error +} + +func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] 
= value + } + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok { + return s.err + } + for key, value := range settings { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + + handler.GetSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 9.5, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) 
+ require.Equal(t, true, data["force_email_on_third_party_signup"]) + + subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any) + require.True(t, ok) + require.Len(t, subscriptions, 1) +} + +func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions]) + require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup]) + + var resp 
response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 12.75, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) +} + +func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "easypay", + "payment_visible_method_wxpay_source": "wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"]) + + var resp response.Response + 
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"]) + require.Equal(t, true, data["payment_visible_method_alipay_enabled"]) + require.Equal(t, false, data["payment_visible_method_wxpay_enabled"]) + require.Equal(t, true, data["openai_advanced_scheduler_enabled"]) +} + +func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodAlipaySource: "", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + service.SettingPaymentVisibleMethodWxpaySource: "", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": false, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) +} + +func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: 
map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "oidc-client", + service.SettingKeyOIDCConnectClientSecret: "oidc-secret", + service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com", + service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth", + service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token", + service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo", + service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectUsePKCE: "true", + service.SettingKeyOIDCConnectValidateIDToken: "true", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "oidc_connect_enabled": true, + "oidc_connect_use_pkce": false, + "oidc_connect_validate_id_token": false, + "oidc_connect_allowed_signing_algs": "", + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "false", 
repo.values[service.SettingKeyOIDCConnectUsePKCE]) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["oidc_connect_use_pkce"]) + require.Equal(t, false, data["oidc_connect_validate_id_token"]) +} + +func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "oidc-client", + service.SettingKeyOIDCConnectClientSecret: "oidc-secret", + service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com", + service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth", + service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token", + service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo", + service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + service.SettingKeyOIDCConnectRequireEmailVerified: "false", + service.SettingKeyOIDCConnectUserInfoEmailPath: "", + service.SettingKeyOIDCConnectUserInfoIDPath: "", + service.SettingKeyOIDCConnectUserInfoUsernamePath: "", + }, + } + svc := service.NewSettingService(repo, &config.Config{ + 
Default: config.DefaultConfig{UserConcurrency: 5}, + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + Scopes: "openid email profile", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + }) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "oidc_connect_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE]) + require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken]) +} + +func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "bogus", + } + rawBody, err := json.Marshal(body) + 
require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource) +} + +func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &failingAuthSourceSettingsRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + }, + err: errors.New("write auth source defaults failed"), + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusInternalServerError, rec.Code) + require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled]) + require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) +} + +func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) { + changed := 
diffSettings( + &service.SystemSettings{}, + &service.SystemSettings{}, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 0, + Concurrency: 5, + Subscriptions: nil, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: false, + }, + &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: 12.5, + Concurrency: 7, + Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + ForceEmailOnThirdPartySignup: true, + }, + UpdateSettingsRequest{}, + ) + + require.Contains(t, changed, "auth_source_default_email_balance") + require.Contains(t, changed, "auth_source_default_email_concurrency") + require.Contains(t, changed, "auth_source_default_email_subscriptions") + require.Contains(t, changed, "auth_source_default_email_grant_on_signup") + require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind") + require.Contains(t, changed, "force_email_on_third_party_signup") +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1453bd0739ddf6a518e010db4f3256fcdc8692fe..3d80107fed58a86a4eda7de07b7e3070debc6527 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -40,6 +40,7 @@ type CreateUserRequest struct { Notes string `json:"notes"` Balance float64 `json:"balance"` Concurrency int `json:"concurrency"` + RPMLimit int `json:"rpm_limit"` AllowedGroups []int64 `json:"allowed_groups"` } @@ -52,6 +53,7 @@ type UpdateUserRequest struct { Notes *string `json:"notes"` Balance *float64 `json:"balance"` Concurrency *int `json:"concurrency"` + RPMLimit *int `json:"rpm_limit"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 @@ -66,6 +68,22 @@ type 
UpdateBalanceRequest struct { Notes string `json:"notes"` } +type BindUserAuthIdentityRequest struct { + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + Issuer *string `json:"issuer"` + Metadata map[string]any `json:"metadata"` + Channel *BindUserAuthIdentityChannelRequest `json:"channel"` +} + +type BindUserAuthIdentityChannelRequest struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -172,6 +190,45 @@ func (h *UserHandler) GetByID(c *gin.Context) { response.Success(c, dto.UserFromServiceAdmin(user)) } +// BindAuthIdentity manually binds a canonical auth identity to a user. +// POST /api/v1/admin/users/:id/auth-identities +func (h *UserHandler) BindAuthIdentity(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req BindUserAuthIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := service.AdminBindAuthIdentityInput{ + ProviderType: req.ProviderType, + ProviderKey: req.ProviderKey, + ProviderSubject: req.ProviderSubject, + Issuer: req.Issuer, + Metadata: req.Metadata, + } + if req.Channel != nil { + input.Channel = &service.AdminBindAuthIdentityChannelInput{ + Channel: req.Channel.Channel, + ChannelAppID: req.Channel.ChannelAppID, + ChannelSubject: req.Channel.ChannelSubject, + Metadata: req.Channel.Metadata, + } + } + + result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + // Create handles creating a new user // POST 
/api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { @@ -188,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, AllowedGroups: req.AllowedGroups, }) if err != nil { @@ -221,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, Status: req.Status, AllowedGroups: req.AllowedGroups, GroupRates: req.GroupRates, @@ -400,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) { "migrated_keys": result.MigratedKeys, }) } + +// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量 +// GET /api/v1/admin/users/:id/rpm-status +func (h *UserHandler) GetUserRPMStatus(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, status) +} diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bfba2408035a9766e8b2ba1235068b7127280052 --- /dev/null +++ b/backend/internal/handler/admin/user_handler_activity_test.go @@ -0,0 +1,114 @@ +//go:build unit + +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) { + gin.SetMode(gin.TestMode) + + lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(30 * time.Minute) + lastUsedAt := lastLoginAt.Add(90 * time.Minute) + + adminSvc := newStubAdminService() + 
adminSvc.users = []service.User{ + { + ID: 7, + Email: "activity@example.com", + Username: "activity-user", + Role: service.RoleUser, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + CreatedAt: lastLoginAt.Add(-24 * time.Hour), + UpdatedAt: lastLoginAt, + }, + } + handler := NewUserHandler(adminSvc, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity", + nil, + ) + + handler.List(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy) + require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder) + require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search) + + var resp struct { + Code int `json:"code"` + Data struct { + Items []struct { + LastActiveAt *time.Time `json:"last_active_at"` + LastUsedAt *time.Time `json:"last_used_at"` + } `json:"items"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Items, 1) + require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second) +} + +func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(30 * time.Minute) + lastUsedAt := lastLoginAt.Add(90 * time.Minute) + + adminSvc := newStubAdminService() + adminSvc.users = []service.User{ + { + ID: 8, + Email: "detail@example.com", + Username: "detail-user", + Role: service.RoleUser, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + CreatedAt: lastLoginAt.Add(-24 * time.Hour), + UpdatedAt: lastLoginAt, + }, + } + handler := 
NewUserHandler(adminSvc, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Params = gin.Params{{Key: "id", Value: "8"}} + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil) + + handler.GetByID(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + LastActiveAt *time.Time `json:"last_active_at"` + LastUsedAt *time.Time `json:"last_used_at"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second) +} diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cb3e4ba596ab9b41314900b967e11ffadfec6673 --- /dev/null +++ b/backend/internal/handler/auth_current_user_test.go @@ -0,0 +1,86 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 31, + Email: "me@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-31", + VerifiedAt: 
&verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, + }, + }, + } + + handler := &AuthHandler{ + userService: service.NewUserService(repo, nil, nil, nil), + } + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31}) + + handler.GetCurrentUser(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + require.Equal(t, "linuxdo", avatarSource["source"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) + require.Equal(t, "linuxdo", usernameSource["source"]) +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index f4ddf890caa50e36acb89bdeb39c0a0ef24d4cc8..dc68a466287fb5c2e9184dcc667262249efb5825 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,11 +1,13 @@ package handler import ( + 
"context" "log/slog" "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -76,9 +78,24 @@ type AuthResponse struct { User *dto.User `json:"user"` } +func ensureLoginUserActive(user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if !user.IsActive() { + return service.ErrUserNotActive + } + return nil +} + // respondWithTokenPair 生成 Token 对并返回认证响应 // 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") if err != nil { slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) @@ -104,6 +121,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { }) } +func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.") +} + +func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error { + if h == nil || !h.isBackendModeEnabled(ctx) { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. 
Only admin login is allowed.") +} + +func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + settings, err := h.settingSvc.GetPublicSettings(ctx) + if err == nil && settings != nil { + return settings.BackendModeEnabled + } + return h.settingSvc.IsBackendModeEnabled(ctx) +} + // Register handles user registration // POST /api/v1/auth/register func (h *AuthHandler) Register(c *gin.Context) { @@ -177,6 +222,11 @@ func (h *AuthHandler) Login(c *gin.Context) { } _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } + // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { // Create a temporary login session for 2FA @@ -194,11 +244,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - // Backend mode: only admin can login - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") - return - } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) h.respondWithTokenPair(c, user) } @@ -262,16 +308,80 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensureLoginUserActive(user); err != nil { + response.ErrorFrom(c, err) + return + } - // Backend mode: only admin can login (check BEFORE deleting session) - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. 
Only admin login is allowed.") + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) return } + if session.PendingOAuthBind != nil { + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + + pendingSession, err := pendingSvc.GetBrowserSession( + c.Request.Context(), + session.PendingOAuthBind.PendingSessionToken, + session.PendingOAuthBind.BrowserSessionKey, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{}) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding( + c.Request.Context(), + h.entClient(), + h.authService, + h.userService, + pendingSession, + decision, + &user.ID, + true, + true, + ); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession( + c.Request.Context(), + pendingSession.SessionToken, + pendingSession.BrowserSessionKey, + ); err != nil { + response.ErrorFrom(c, err) + return + } + + secureCookie := isRequestHTTPS(c) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + + user, err = h.userService.GetByID(c.Request.Context(), session.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + // Delete the login session (only after all checks pass) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + if session.PendingOAuthBind == nil { + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + } + h.respondWithTokenPair(c, user) } @@ -290,8 +400,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { return } + identities, err := 
h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user) + if err != nil { + response.ErrorFrom(c, err) + return + } + type UserResponse struct { - *dto.User + userProfileResponse RunMode string `json:"run_mode"` } @@ -300,7 +416,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { runMode = h.cfg.RunMode } - response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) + response.Success(c, UserResponse{ + userProfileResponse: userProfileResponseFromService(user, identities), + RunMode: runMode, + }) } // ValidatePromoCodeRequest 验证优惠码请求 @@ -578,6 +697,8 @@ func (h *AuthHandler) Logout(c *gin.Context) { // 不影响登出流程 } } + h.consumePendingOAuthSessionOnLogout(c) + clearOAuthLogoutCookies(c) response.Success(c, LogoutResponse{ Message: "Logged out successfully", @@ -598,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { return } - if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil { slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) response.InternalError(c, "Failed to revoke sessions") return diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0c7c2da7ab49a5dcfb557dfcee1a75300687a8de..2ef059636b7a0f068777aa5617a00282ac637e7e 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -13,10 +15,13 @@ import ( "time" "unicode/utf8" + dbent "github.com/Wei-Shaw/sub2api/ent" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" 
"github.com/Wei-Shaw/sub2api/internal/pkg/response" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -25,17 +30,24 @@ import ( ) const ( - linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" - linuxDoOAuthStateCookieName = "linuxdo_oauth_state" - linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" - linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" - linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - linuxDoOAuthDefaultRedirectTo = "/dashboard" - linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent" + linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user" + oauthBindAccessTokenCookieName = "oauth_bind_access_token" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" linuxDoOAuthMaxRedirectLen = 2048 linuxDoOAuthMaxFragmentValueLen = 512 linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") + + oauthIntentLogin = "login" + oauthIntentBindCurrentUser = "bind_current_user" ) type linuxDoTokenResponse struct { @@ -87,9 +99,29 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) 
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + intent := normalizeOAuthIntent(c.Query("intent")) + setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } else { + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + } codeChallenge := "" if cfg.UsePKCE { @@ -148,6 +180,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) @@ -161,6 +195,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) codeVerifier := "" if cfg.UsePKCE { @@ -198,52 +239,204 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + email, username, subject, 
displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") return } + compatEmail := strings.TrimSpace(email) // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 if subject != "" { email = linuxDoSyntheticEmail(subject) } + identityKey := service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: identityKey, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } - // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey) 
if err != nil { - if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") - return - } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityKey, + TargetUserID: &existingIdentityUser.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + redirectToFrontendCallback(c, frontendCallback) return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createLinuxDoOAuthChoicePendingSession( + c, + 
identityKey, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + h.isForceEmailOnThirdPartySignup(c.Request.Context()), + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). 
+ All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + switch len(userEntity) { + case 0: + return nil, nil + case 1: + return userEntity[0], nil + default: + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } +} + +func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + if forceEmailOnSignup && compatEmailUser == nil { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + var targetUserID *int64 + if 
compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + TargetUserID: targetUserID, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) } type completeLinuxDoOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -256,17 +449,87 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + 
clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) if err != nil { response.ErrorFrom(c, err) return } + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, 
h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -303,7 +566,7 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { + if strings.TrimSpace(codeVerifier) != "" { form.Set("code_verifier", codeVerifier) } @@ -353,11 +616,11 @@ func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, -) (email string, username string, subject string, err error) { +) (email string, username string, subject string, displayName string, avatarURL string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { - return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). @@ -366,16 +629,16 @@ func linuxDoFetchUserInfo( SetHeader("Authorization", authorization). 
Get(cfg.UserInfoURL) if err != nil { - return "", "", "", fmt.Errorf("request userinfo: %w", err) + return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { - return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } -func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), @@ -400,12 +663,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s getGJSON(body, "user.id"), ) + displayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "user.name"), + getGJSON(body, "user.username"), + username, + ) + avatarURL = firstNonEmpty( + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "picture"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) + subject = strings.TrimSpace(subject) if subject == "" { - return "", "", "", errors.New("userinfo missing id field") + return "", "", "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { - return "", "", "", errors.New("userinfo returned invalid id field") + return "", "", "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) @@ -418,8 +698,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s if username == "" { username = "linuxdo_" + subject } + displayName = strings.TrimSpace(displayName) + if displayName == "" { + 
displayName = username + } + avatarURL = strings.TrimSpace(avatarURL) - return email, username, subject, nil + return email, username, subject, displayName, avatarURL, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { @@ -436,7 +721,7 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - if cfg.UsePKCE { + if strings.TrimSpace(codeChallenge) != "" { q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") } @@ -670,6 +955,30 @@ func clearCookie(c *gin.Context, name string, secure bool) { }) } +func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthBindAccessTokenCookieName, + Value: "", + Path: oauthBindAccessTokenCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthBindAccessTokenCookieName, + Value: url.QueryEscape(strings.TrimSpace(token)), + Path: oauthBindAccessTokenCookiePath, + MaxAge: linuxDoOAuthCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + func truncateFragmentValue(value string) string { value = strings.TrimSpace(value) if value == "" { @@ -728,3 +1037,127 @@ func linuxDoSyntheticEmail(subject string) string { } return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain } + +func normalizeOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", oauthIntentLogin: + return oauthIntentLogin + case "bind", oauthIntentBindCurrentUser: + return oauthIntentBindCurrentUser + default: + return oauthIntentLogin + } +} + +func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) { + 
userID, err := h.resolveOAuthBindTargetUserID(c) + if err != nil || userID == nil || *userID <= 0 { + return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required") + } + return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) { + const bearerPrefix = "Bearer " + + authHeader := strings.TrimSpace(c.GetHeader("Authorization")) + if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) { + response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")) + return + } + + token := strings.TrimSpace(authHeader[len(bearerPrefix):]) + if token == "" { + response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")) + return + } + + setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c)) + c.Status(http.StatusNoContent) + c.Writer.WriteHeaderNow() +} + +func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) { + if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return &subject.UserID, nil + } + if h == nil || h.authService == nil || h.userService == nil { + return nil, service.ErrInvalidToken + } + + ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName) + clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c)) + if err != nil { + return nil, err + } + + tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value)) + if err != nil { + return nil, err + } + if tokenString == "" { + return nil, service.ErrInvalidToken + } + + claims, err := h.authService.ValidateToken(tokenString) + if err != nil { + return nil, err + } + user, err := h.userService.GetByID(c.Request.Context(), claims.UserID) + if err != nil { + return nil, err + } + if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion { + return nil, service.ErrInvalidToken + } + return &user.ID, nil +} + +func (h *AuthHandler) 
readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) { + value, err := readCookieDecoded(c, cookieName) + if err != nil { + return 0, err + } + return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret()) +} + +func (h *AuthHandler) oauthBindCookieSecret() string { + if h == nil || h.cfg == nil { + return "" + } + return strings.TrimSpace(h.cfg.JWT.Secret) +} + +func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) { + secret = strings.TrimSpace(secret) + if userID <= 0 || secret == "" { + return "", errors.New("invalid oauth bind cookie input") + } + payload := strconv.FormatInt(userID, 10) + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return payload + "." + signature, nil +} + +func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) { + secret = strings.TrimSpace(secret) + if secret == "" { + return 0, errors.New("missing oauth bind cookie secret") + } + payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".") + if !ok || payload == "" || signature == "" { + return 0, errors.New("invalid oauth bind cookie") + } + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write([]byte(payload)) + expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { + return 0, errors.New("invalid oauth bind cookie signature") + } + userID, err := strconv.ParseInt(payload, 10, 64) + if err != nil || userID <= 0 { + return 0, errors.New("invalid oauth bind cookie user") + } + return userID, nil +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index ff169c52ad694b76b53ea02892bbd05496aadd4a..8b01ab417f894ccc3006e08cd605f63068b49b1e 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ 
b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -1,10 +1,24 @@ package handler import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "net/url" "strings" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -41,11 +55,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "alice", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "Alice", displayName) + require.Equal(t, "https://cdn.example/avatar.png", avatarURL) } func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { @@ -53,11 +69,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "linuxdo_123", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "linuxdo_123", displayName) + require.Equal(t, "", avatarURL) } func 
TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { @@ -65,11 +83,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) require.Error(t, err) tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) - _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) require.Error(t, err) } @@ -106,3 +124,906 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "", singleLine("\n\t\r")) } + +func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42}) + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "connect.linux.do/oauth/authorize") + require.Contains(t, location, "client_id=linuxdo-client") + require.Contains(t, location, 
"code_challenge=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(42), userID) +} + +func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) { + handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil) + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=") + require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie)) +} + +func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) { + upstream := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, r.ParseForm()) + require.Empty(t, r.PostForm.Get("code_verifier")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) +} + +func 
TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: "https://connect.linux.do/oauth/authorize", + TokenURL: "https://connect.linux.do/oauth/token", + UserInfoURL: "https://connect.linux.do/api/user", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("bind-cookie@example.com"). + SetUsername("bind-cookie-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + token, err := handler.authService.GenerateToken(&service.User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + PasswordHash: user.PasswordHash, + Role: user.Role, + Status: user.Status, + }) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath}) + c.Request = req + + handler.LinuxDoOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + + bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, user.ID, userID) + + accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName) + 
require.NotNil(t, accessTokenCookie) + require.Equal(t, -1, accessTokenCookie.MaxAge) +} + +func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) { + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{}) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil) + req.Header.Set("Authorization", "Bearer access-token-value") + c.Request = req + + handler.PrepareOAuthBindAccessTokenCookie(c) + + require.Equal(t, http.StatusNoContent, recorder.Code) + accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName) + require.NotNil(t, accessTokenCookie) + require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path) + require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge) + require.True(t, accessTokenCookie.HttpOnly) + require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value) +} + +func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", 
+ UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(linuxDoSyntheticEmail("321")). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("321"). + SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail) + require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) + require.Nil(t, completion["error"]) +} + +func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). 
+ SetEmail(linuxDoSyntheticEmail("654")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("654"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo 
Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(" Legacy@Example.com "). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). 
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"]) + require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + +func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL 
+ "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + +func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-1"). + SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Display", + "suggested_avatar_url": "https://cdn.example/linuxdo.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "LinuxDo Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("linuxdo-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-choice-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-choice-subject-1"). + SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-no-adoption"). + SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Legacy", + "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png", + }). 
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "linuxdo_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + +func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). 
+ SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-conflict-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-conflict-subject"). + SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")). 
+ Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + LinuxDo: oauthCfg, + } + return handler, client +} diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0d4f94b1373ec225855604dfca871a32aa6abc7e --- /dev/null +++ b/backend/internal/handler/auth_oauth_logout_test.go @@ -0,0 +1,68 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("logout-pending-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("logout-subject-123"). + SetBrowserSessionKey("logout-browser-session-key"). + SetResolvedEmail("logout@example.com"). 
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"}) + req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")}) + req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")}) + req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")}) + req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + cookies := recorder.Result().Cookies() + for _, name := range []string{ + oauthPendingSessionCookieName, + oauthPendingBrowserCookieName, + oauthBindAccessTokenCookieName, + linuxDoOAuthStateCookieName, + oidcOAuthStateCookieName, + wechatOAuthStateCookieName, + wechatPaymentOAuthStateName, + } { + cookie := findCookie(cookies, name) + require.NotNil(t, cookie, name) + require.Equal(t, -1, cookie.MaxAge, name) + require.True(t, cookie.HttpOnly, name) + } + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..604ad903a5568544e3e0ff54af8b84fe89bfa9dd --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -0,0 +1,1944 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" + "github.com/gin-gonic/gin" +) + +const ( + oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" + oauthPendingBrowserCookieName = "oauth_pending_browser_session" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth" + oauthPendingSessionCookieName = "oauth_pending_session" + oauthPendingCookieMaxAgeSec = 10 * 60 + oauthPendingChoiceStep = "choose_account_action_required" + + oauthCompletionResponseKey = "completion_response" +) + +var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error + +type oauthPendingSessionPayload struct { + Intent string + Identity service.PendingAuthIdentityKey + TargetUserID *int64 + ResolvedEmail string + RedirectTo string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + CompletionResponse map[string]any +} + +type oauthAdoptionDecisionRequest struct { + 
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type bindPendingOAuthLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type createPendingOAuthAccountRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code,omitempty"` + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type sendPendingOAuthVerifyCodeRequest struct { + Email string `json:"email" binding:"required,email"` + TurnstileToken string `json:"turnstile_token,omitempty"` + PendingAuthToken string `json:"pending_auth_token,omitempty"` + PendingOAuthToken string `json:"pending_oauth_token,omitempty"` +} + +func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { + if h == nil || h.authService == nil || h.authService.EntClient() == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil +} + +func generateOAuthPendingBrowserSession() (string, error) { + return oauth.GenerateState() +} + +func 
setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: encodeCookieValue(sessionKey), + Path: oauthPendingBrowserCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: "", + Path: oauthPendingBrowserCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingBrowserCookieName) +} + +func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: encodeCookieValue(sessionToken), + Path: oauthPendingSessionCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: "", + Path: oauthPendingSessionCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingSessionCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingSessionCookieName) +} + +func redirectToFrontendCallback(c *gin.Context, frontendCallback string) { + u, err := url.Parse(frontendCallback) + if err != nil { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = "" + 
c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error { + svc, err := h.pendingIdentityService() + if err != nil { + return err + } + + session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ + Intent: strings.TrimSpace(payload.Intent), + Identity: payload.Identity, + TargetUserID: payload.TargetUserID, + ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), + RedirectTo: strings.TrimSpace(payload.RedirectTo), + BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), + UpstreamIdentityClaims: payload.UpstreamIdentityClaims, + LocalFlowState: map[string]any{ + oauthCompletionResponseKey: payload.CompletionResponse, + }, + }) + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + + setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c)) + return nil +} + +func readCompletionResponse(session map[string]any) (map[string]any, bool) { + if len(session) == 0 { + return nil, false + } + value, ok := session[oauthCompletionResponseKey] + if !ok { + return nil, false + } + result, ok := value.(map[string]any) + if !ok { + return nil, false + } + return result, true +} + +func clonePendingMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any { + payload, _ := readCompletionResponse(session.LocalFlowState) + merged := clonePendingMap(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := merged["redirect"]; !exists { + 
merged["redirect"] = session.RedirectTo + } + } + for key, value := range overrides { + if value == nil { + delete(merged, key) + continue + } + merged[key] = value + } + applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims) + return merged +} + +func pendingSessionStringValue(values map[string]any, key string) string { + if len(values) == 0 { + return "" + } + raw, ok := values[key] + if !ok { + return "" + } + value, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(value) +} + +func pendingSessionWantsInvitation(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") +} + +func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool { + if session == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { + return false + } + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + return false + } + if pendingSessionWantsInvitation(payload) { + return false + } + return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == "" +} + +func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error { + if session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if strings.TrimSpace(session.Intent) != oauthIntentLogin { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + payload, _ := readCompletionResponse(session.LocalFlowState) + if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") { + return 
infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + return nil +} + +func buildLegacyCompleteRegistrationPendingResponse( + session *dbent.PendingAuthSession, + forceEmailOnSignup bool, + emailVerificationRequired bool, +) map[string]any { + completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + })) + + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + if _, exists := completionResponse["email"]; !exists { + completionResponse["email"] = email + } + if _, exists := completionResponse["resolved_email"]; !exists { + completionResponse["resolved_email"] = email + } + } + if _, exists := completionResponse["choice_reason"]; !exists { + switch { + case forceEmailOnSignup: + completionResponse["choice_reason"] = "force_email_on_signup" + case emailVerificationRequired: + completionResponse["choice_reason"] = "email_verification_required" + default: + completionResponse["choice_reason"] = "third_party_signup" + } + } + return completionResponse +} + +func (h *AuthHandler) legacyCompleteRegistrationSessionStatus( + c *gin.Context, + session *dbent.PendingAuthSession, +) (*dbent.PendingAuthSession, bool, error) { + if session == nil { + return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) + if step := pendingSessionStringValue(payload, "step"); step != "" { + return session, true, nil + } + + emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context()) + forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context()) + if !emailVerificationRequired && 
!forceEmailOnSignup { + return session, false, nil + } + + client := h.entClient() + if client == nil { + return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + updatedSession, err := updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + strings.TrimSpace(session.Intent), + strings.TrimSpace(session.ResolvedEmail), + nil, + buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired), + ) + if err != nil { + return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + return updatedSession, true, nil +} + +func (r oauthAdoptionDecisionRequest) hasDecision() bool { + return r.AdoptDisplayName != nil || r.AdoptAvatar != nil +} + +func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { + var req oauthAdoptionDecisionRequest + if c == nil || c.Request == nil || c.Request.Body == nil { + return req, nil + } + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + return req, nil + } + return req, err + } + return req, nil +} + +func cloneOAuthMetadata(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any { + merged := cloneOAuthMetadata(base) + for key, value := range overlay { + merged[key] = value + } + return merged +} + +func normalizeAdoptedOAuthDisplayName(value string) string { + value = strings.TrimSpace(value) + if len([]rune(value)) > 100 { + value = string([]rune(value)[:100]) + } + return value +} + +func (h *AuthHandler) entClient() *dbent.Client { + if h == nil || h.authService == nil { + return nil + } + return 
h.authService.EntClient() +} + +func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx) + if err != nil || defaults == nil { + return false + } + return defaults.ForceEmailOnThirdPartySignup +} + +func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return findActiveUserByID(ctx, client, record.UserID) +} + +func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } +func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } +func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } +func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") } + +func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "linuxdo") +} + +func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") } + +func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "wechat") +} + +func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { 
+ h.createPendingOAuthAccount(c, "") +} + +// SendPendingOAuthVerifyCode sends a verification code for a browser-bound +// pending OAuth account-creation flow. +// POST /api/v1/auth/oauth/pending/send-verify-code +func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { + var req sendPendingOAuthVerifyCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return + } + + _, session, _, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } else if err != nil && !errors.Is(err, service.ErrUserNotFound) { + response.ErrorFrom(c, err) + return + } + + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, SendVerifyCodeResponse{ + Message: "Verification code sent successfully", + Countdown: result.Countdown, + }) +} + +func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) 
(*dbent.IdentityAdoptionDecision, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + existing, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)). + Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err) + } + if existing != nil && !req.hasDecision() { + return existing, nil + } + if existing == nil && !req.hasDecision() { + return nil, nil + } + + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if existing != nil { + input.AdoptDisplayName = existing.AdoptDisplayName + input.AdoptAvatar = existing.AdoptAvatar + input.IdentityID = existing.IdentityID + } + if req.AdoptDisplayName != nil { + input.AdoptDisplayName = *req.AdoptDisplayName + } + if req.AdoptAvatar != nil { + input.AdoptAvatar = *req.AdoptAvatar + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func (h *AuthHandler) ensurePendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req) + if err != nil { + return nil, err + } + if decision != nil { + return decision, nil + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), 
service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + }) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func updatePendingOAuthSessionProgress( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + intent string, + resolvedEmail string, + targetUserID *int64, + completionResponse map[string]any, +) (*dbent.PendingAuthSession, error) { + if client == nil || session == nil { + return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + + localFlowState := clonePendingMap(session.LocalFlowState) + localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse) + + update := client.PendingAuthSession.UpdateOneID(session.ID). + SetIntent(strings.TrimSpace(intent)). + SetResolvedEmail(strings.TrimSpace(resolvedEmail)). + SetLocalFlowState(localFlowState) + if targetUserID != nil && *targetUserID > 0 { + update = update.SetTargetUserID(*targetUserID) + } else { + update = update.ClearTargetUserID() + } + return update.Save(ctx) +} + +func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { + if session == nil { + return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return *session.TargetUserID, nil + } + email := strings.TrimSpace(session.ResolvedEmail) + if email == "" { + return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") + } + + userEntity, err := findUserByNormalizedEmail(ctx, client, email) + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user 
was not found") + } + return 0, err + } + return userEntity.ID, nil +} + +func userNormalizedEmailPredicate(email string) predicate.User { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.P(func(b *entsql.Builder) { + b.WriteString("LOWER(TRIM("). + Ident(s.C(dbuser.FieldEmail)). + WriteString(")) = "). + Arg(normalized) + })) + }) +} + +func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) { + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + matches, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) + if err != nil { + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } + return matches[0], nil +} + +func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error { + if client == nil || session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). 
+ Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity == nil || identity.UserID <= 0 { + return nil + } + + activeOwner, err := findActiveUserByID(ctx, client, identity.UserID) + if err != nil { + return err + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return nil +} + +func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { + if session == nil { + return nil + } + switch strings.TrimSpace(session.ProviderType) { + case "oidc": + issuer := strings.TrimSpace(session.ProviderKey) + if issuer == "" { + issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + } + if issuer == "" { + return nil + } + return &issuer + default: + issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + if issuer == "" { + return nil + } + return &issuer + } +} + +func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") { + return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID) + } + + client := tx.Client() + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). 
+ Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil { + if identity.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, identity.UserID) + if err != nil { + return nil, err + } + if activeOwner != nil { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return client.AuthIdentity.UpdateOneID(identity.ID). + SetUserID(userID). + Save(ctx) + } + return identity, nil + } + + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(strings.TrimSpace(session.ProviderType)). + SetProviderKey(strings.TrimSpace(session.ProviderKey)). + SetProviderSubject(strings.TrimSpace(session.ProviderSubject)). + SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + return create.Save(ctx) +} + +func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(providerKey) + channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) + channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) + channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) + metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) + + identityRecords, err := client.AuthIdentity.Query(). 
+ Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + All(ctx) + if err != nil { + return nil, err + } + identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey) + if err != nil { + return nil, err + } + + var legacyOpenIDIdentity *dbent.AuthIdentity + if channelSubject != "" && channelSubject != providerSubject { + legacyOpenIDRecords, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(channelSubject), + ). + All(ctx) + if err != nil { + return nil, err + } + legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey) + if err != nil { + return nil, err + } + } + + switch { + case identity != nil: + update := client.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if identity.UserID != userID { + update = update.SetUserID(userID) + } + if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey { + update = update.SetProviderKey(providerKey) + } + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + case legacyOpenIDIdentity != nil: + update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + default: + create := client.AuthIdentity.Create(). 
+ SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, err + } + } + + if channel == "" || channelAppID == "" || channelSubject == "" { + return identity, nil + } + + channelRecords, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + All(ctx) + if err != nil { + return nil, err + } + channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey) + if err != nil { + return nil, err + } + + channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) + if channelRecord == nil { + if _, err := client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channel). + SetChannelAppID(channelAppID). + SetChannelSubject(channelSubject). + SetMetadata(channelMetadata). + Save(ctx); err != nil { + return nil, err + } + return identity, nil + } + + updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + SetIdentityID(identity.ID). 
+ SetMetadata(channelMetadata) + if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey { + updateChannel = updateChannel.SetProviderKey(providerKey) + } + _, err = updateChannel.Save(ctx) + if err != nil { + return nil, err + } + return identity, nil +} + +func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) { + var preferred *dbent.AuthIdentity + var fallback *dbent.AuthIdentity + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, record.UserID) + if err != nil { + return nil, false, err + } + if activeOwner != nil { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) { + var preferred *dbent.AuthIdentityChannel + var fallback *dbent.AuthIdentityChannel + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID) + if err != nil { + return nil, false, err + } + if activeOwner != nil { + return nil, false, 
infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) { + if client == nil || userID <= 0 { + return nil, nil + } + userEntity, err := client.User.Get(ctx, userID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) { + return nil, service.ErrUserNotActive + } + return userEntity, nil +} + +func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { + if channel == nil { + return map[string]any{} + } + return cloneOAuthMetadata(channel.Metadata) +} + +func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { + if session == nil || decision == nil { + return false + } + switch strings.ToLower(strings.TrimSpace(session.Intent)) { + case "bind_current_user", "login", "adopt_existing_user_by_email": + return true + default: + return decision.AdoptDisplayName || decision.AdoptAvatar + } +} + +func shouldSkipAvatarAdoption(err error) bool { + return errors.Is(err, service.ErrAvatarInvalid) || + errors.Is(err, service.ErrAvatarTooLarge) || + errors.Is(err, service.ErrAvatarNotImage) +} + +func applyPendingOAuthBinding( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session 
*dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, + forceBind bool, + applyFirstBindDefaults bool, +) error { + if client == nil || session == nil { + return nil + } + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults) + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil { + return err + } + return tx.Commit() +} + +func applyPendingOAuthBindingTx( + ctx context.Context, + tx *dbent.Tx, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, + forceBind bool, + applyFirstBindDefaults bool, +) error { + if tx == nil || session == nil { + return nil + } + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { + return nil + } + + targetUserID := int64(0) + if overrideUserID != nil && *overrideUserID > 0 { + targetUserID = *overrideUserID + } else { + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session) + if err != nil { + return err + } + targetUserID = resolvedUserID + } + + adoptedDisplayName := "" + if decision != nil && decision.AdoptDisplayName { + adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) + } + adoptedAvatarURL := "" + if decision != nil && decision.AdoptAvatar { + adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") + } + 
shouldAdoptAvatar := false + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { + if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil { + shouldAdoptAvatar = true + } else if !shouldSkipAvatarAdoption(err) { + return err + } + } + + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { + if err := tx.Client().User.UpdateOneID(targetUserID). + SetUsername(adoptedDisplayName). + Exec(ctx); err != nil { + return err + } + } + + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + if err != nil { + return err + } + + metadata := cloneOAuthMetadata(identity.Metadata) + for key, value := range session.UpstreamIdentityClaims { + metadata[key] = value + } + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { + metadata["display_name"] = adoptedDisplayName + } + if shouldAdoptAvatar { + metadata["avatar_url"] = adoptedAvatarURL + } + + updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) + } + if _, err := updateIdentity.Save(ctx); err != nil { + return err + } + + if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { + if _, err := tx.Client().IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(identity.ID), + identityadoptiondecision.IDNEQ(decision.ID), + ). + ClearIdentityID(). + Save(ctx); err != nil { + return err + } + if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). + SetIdentityID(identity.ID). 
+ Save(ctx); err != nil { + return err + } + } + + if applyFirstBindDefaults && authService != nil { + if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil { + return err + } + } + + if shouldAdoptAvatar && userService != nil { + if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil { + return err + } + } + + return nil +} + +func consumePendingOAuthBrowserSessionTx( + ctx context.Context, + tx *dbent.Tx, + session *dbent.PendingAuthSession, +) error { + if tx == nil || session == nil { + return service.ErrPendingAuthSessionNotFound + } + + storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID) + if err != nil { + if dbent.IsNotFound(err) { + return service.ErrPendingAuthSessionNotFound + } + return err + } + + now := time.Now().UTC() + if storedSession.ConsumedAt != nil { + return service.ErrPendingAuthSessionConsumed + } + if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) { + return service.ErrPendingAuthSessionExpired + } + if strings.TrimSpace(storedSession.BrowserSessionKey) != "" && + strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return service.ErrPendingAuthBrowserMismatch + } + + if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). 
+ Save(ctx); err != nil { + return err + } + + return nil +} + +func applyPendingOAuthAdoptionAndConsumeSession( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + userID int64, +) error { + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if session == nil || userID <= 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil { + return err + } + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + return err + } + return tx.Commit() +} + +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + return applyPendingOAuthBinding( + ctx, + client, + authService, + userService, + session, + decision, + overrideUserID, + false, + strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"), + ) +} + +func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { + if len(payload) == 0 || len(upstream) == 0 { + return + } + + displayName := pendingSessionStringValue(upstream, "suggested_display_name") + avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url") + + if displayName != "" { + if _, exists := payload["suggested_display_name"]; !exists { + payload["suggested_display_name"] = displayName + } + } + if avatarURL != "" 
{ + if _, exists := payload["suggested_avatar_url"]; !exists { + payload["suggested_avatar_url"] = avatarURL + } + } + if displayName != "" || avatarURL != "" { + payload["adoption_required"] = true + } +} + +func pendingOAuthIdentityExistsForUser( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + userID int64, +) (bool, error) { + if client == nil || session == nil || userID <= 0 { + return false, nil + } + + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + if providerType == "" || providerSubject == "" { + return false, nil + } + + query := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderSubjectEQ(providerSubject), + authidentity.UserIDEQ(userID), + ) + if strings.EqualFold(providerType, "wechat") { + query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...)) + } else if providerKey != "" { + query = query.Where(authidentity.ProviderKeyEQ(providerKey)) + } + + count, err := query.Count(ctx) + if err != nil { + return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return count > 0, nil +} + +func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt( + ctx context.Context, + session *dbent.PendingAuthSession, + payload map[string]any, +) (bool, error) { + if session == nil || len(payload) == 0 { + return false, nil + } + if !pendingOAuthCompletionCanIssueTokenPair(session, payload) { + return false, nil + } + if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" && + pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" { + return false, nil + } + + return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID) +} + +func 
readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + return svc, session, clearCookies, nil +} + +func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) { + if c == nil || c.Request == nil { + return + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + return + } + _, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) +} + +func clearOAuthLogoutCookies(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + clearOAuthBindAccessTokenCookie(c, secureCookie) + + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + 
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + + oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) + oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) + oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) + + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + + wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie) +} + +func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { + completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) + payload := gin.H{ + "auth_result": "pending_session", + "provider": strings.TrimSpace(session.ProviderType), + "intent": strings.TrimSpace(session.Intent), + } + for key, value := range completionResponse { + payload[key] = value + } + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + payload["email"] = email + } + return payload +} + +func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any { + normalized := clonePendingMap(payload) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(normalized, key) + } + step := 
strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) + switch step { + case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": + normalized["step"] = oauthPendingChoiceStep + } + if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) { + normalized["adoption_required"] = true + } + if _, exists := normalized["adoption_required"]; !exists { + if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields { + normalized["adoption_required"] = true + } + } + return normalized +} + +func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any { + response := mergePendingCompletionResponse(session, map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "force_email_on_signup": true, + "email_binding_required": true, + "existing_account_bindable": true, + }) + if email = strings.TrimSpace(email); email != "" { + response["email"] = email + response["resolved_email"] = email + } + return response +} + +func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState( + c *gin.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + targetUser *dbent.User, + email string, +) (*dbent.PendingAuthSession, error) { + completionResponse := pendingOAuthChoiceCompletionResponse(session, email) + var targetUserID *int64 + if targetUser != nil && targetUser.ID > 0 { + targetUserID = &targetUser.ID + } + session, err := updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + strings.TrimSpace(session.Intent), + email, + targetUserID, + completionResponse, + ) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + return session, nil +} + +func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { + 
c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { + var req bindPendingOAuthLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID { + response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { + tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession( + c.Request.Context(), + user.ID, + user.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + if err != nil { + response.InternalError(c, "Failed to create 2FA session") + return + } + response.Success(c, TotpLoginResponse{ + Requires2FA: true, + TempToken: tempToken, + UserEmailMasked: service.MaskEmail(user.Email), + }) + 
return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + response.InternalError(c, "Failed to generate token pair") + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func respondPendingOAuthBindingApplyError(c *gin.Context, err error) { + if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError { + response.ErrorFrom(c, err) + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) +} + +func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) { + var req createPendingOAuthAccountRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := 
strings.TrimSpace(strings.ToLower(req.Email)) + existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email) + if err != nil { + switch { + case errors.Is(err, service.ErrUserNotFound): + existingUser = nil + case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError: + response.ErrorFrom(c, err) + return + default: + response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) + return + } + } + if existingUser != nil { + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( + c.Request.Context(), + email, + req.Password, + strings.TrimSpace(req.VerifyCode), + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + if errors.Is(err, service.ErrEmailExists) { + existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email) + if lookupErr != nil { + response.ErrorFrom(c, lookupErr) + return + } + session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + response.ErrorFrom(c, err) + return + } + + rollbackCreatedUser := func(originalErr error) bool { + if user == nil || user.ID <= 0 { + return false + } + if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation( + c.Request.Context(), + user.ID, + strings.TrimSpace(req.InvitationCode), + ); rollbackErr != nil { + response.ErrorFrom(c, 
infraerrors.InternalServer( + "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED", + "failed to rollback pending oauth account creation", + ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr))) + return true + } + user = nil + return false + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, err) + return + } + + tx, err := client.Tx(c.Request.Context()) + if err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(c.Request.Context(), tx) + + if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + respondPendingOAuthBindingApplyError(c, err) + return + } + + if err := h.authService.FinalizeOAuthEmailAccount( + txCtx, + user, + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, err) + return + } + + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + clearCookies() + response.ErrorFrom(c, err) + return + } + + if pendingOAuthCreateAccountPreCommitHook != nil { + if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + respondPendingOAuthBindingApplyError(c, err) + return + } + } + + if err := tx.Commit(); err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, 
infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. +// POST /api/v1/auth/oauth/pending/exchange +func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c) + if err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + payload, ok := readCompletionResponse(session.LocalFlowState) + if !ok { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + payload = normalizePendingOAuthCompletionResponse(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := payload["redirect"]; !exists { + payload["redirect"] = session.RedirectTo + } + } + 
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload) + var loginUser *service.User + if canIssueTokenPair { + loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := ensureLoginUserActive(loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + } + skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if skipAdoptionPrompt { + delete(payload, "adoption_required") + } + + if pendingSessionWantsInvitation(payload) { + if adoptionDecision.hasDecision() { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + _ = decision + } + response.Success(c, payload) + return + } + if !adoptionDecision.hasDecision() { + adoptionRequired, _ := payload["adoption_required"].(bool) + if adoptionRequired { + response.Success(c, payload) + return + } + } + + decisionReq := adoptionDecision + if !decisionReq.hasDecision() { + adoptDisplayName := false + adoptAvatar := false + decisionReq = oauthAdoptionDecisionRequest{ + AdoptDisplayName: &adoptDisplayName, + AdoptAvatar: &adoptAvatar, + } + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil { + response.ErrorFrom(c, 
infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + + if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + if canIssueTokenPair { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "") + if err != nil { + clearCookies() + response.InternalError(c, "Failed to generate token pair") + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID) + payload["access_token"] = tokenPair.AccessToken + payload["refresh_token"] = tokenPair.RefreshToken + payload["expires_in"] = tokenPair.ExpiresIn + payload["token_type"] = "Bearer" + } + + clearCookies() + response.Success(c, payload) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a4b7a2979ece25915e1eb7d7c523e51dab2da64a --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -0,0 +1,2995 @@ +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ 
"modernc.org/sqlite" +) + +func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { + payload := map[string]any{ + "access_token": "token", + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Alice", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) { + payload := map[string]any{ + "suggested_display_name": "Existing", + "adoption_required": false, + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Existing", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false) + + cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, cookie) + require.Equal(t, "/api/v1/auth/oauth", cookie.Path) +} + +func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("linuxdo-123@linuxdo-connect.invalid"). + SetUsername("legacy-name"). 
+ SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "https://cdn.example/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Alice Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + previewSession, err := client.PendingAuthSession.Query(). 
+ Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err = client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "Alice Example", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.NotNil(t, avatar) + require.Equal(t, "remote_url", avatar.StorageProvider) + require.Equal(t, "https://cdn.example/alice.png", avatar.URL) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("invalid-avatar@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-invalid-avatar-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invalid-avatar-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-invalid-avatar-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "/avatars/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invalid-avatar-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + _, hasAdoptedAvatar := identity.Metadata["avatar_url"] + require.False(t, hasAdoptedAvatar) + + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.Nil(t, avatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-pending-session-token"). + SetIntent("bind_current_user"). 
+ SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("bind-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Bound Example", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/settings/profile", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Bound Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"]) + require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"]) + _, hasDisplayName := identity.Metadata["display_name"] + require.False(t, hasDisplayName) + _, hasAvatarURL := identity.Metadata["avatar_url"] + require.False(t, hasAvatarURL) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("bind-conflict-target@example.com"). + SetUsername("target-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerUser, err := client.User.Create(). + SetEmail("bind-conflict-owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + existingIdentity, err := client.AuthIdentity.Create(). + SetUserID(ownerUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetMetadata(map[string]any{"username": "owner-user"}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-conflict-session-token"). + SetIntent("bind_current_user"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("conflict-123"). + SetTargetUserID(targetUser.ID). + SetResolvedEmail(targetUser.Email). + SetBrowserSessionKey("bind-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Conflict Example", + "suggested_avatar_url": "https://cdn.example/conflict.png", + }). 
+ SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"]) + + identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID) + require.NoError(t, err) + require.Equal(t, ownerUser.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-false@example.com"). + SetUsername("legacy-name"). 
+ SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-false-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-false-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-false-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Login Example", + "suggested_avatar_url": "https://cdn.example/login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("login-false-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-reassign@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + existingIdentity, err := client.AuthIdentity.Create(). + SetUserID(userEntity.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + previousSession, err := client.PendingAuthSession.Create(). + SetSessionToken("login-reassign-previous-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-reassign-previous-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "previous-access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previousDecision, err := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(previousSession.ID). + SetIdentityID(existingIdentity.ID). + SetAdoptDisplayName(true). + SetAdoptAvatar(true). 
+ Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-reassign-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-reassign-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-reassign-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Login Reassign", + "suggested_avatar_url": "https://cdn.example/login-reassign.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(session.ID). + SetAdoptDisplayName(false). + SetAdoptAvatar(false). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-reassign-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + reloadedPrevious, err := client.IdentityAdoptionDecision.Get(ctx, previousDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedPrevious.IdentityID) + + currentDecision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, currentDecision.IdentityID) + require.Equal(t, existingIdentity.ID, *currentDecision.IdentityID) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("login-nodecision@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-nodecision-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("login-nodecision-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("login-nodecision-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "login-nodecision-user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("login-nodecision-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("existing-login@example.com"). + SetUsername("existing-login-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(userEntity.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetMetadata(map[string]any{ + "username": "existing-login-user", + }). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-login-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). 
+ SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("existing-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Existing Login Example", + "suggested_avatar_url": "https://cdn.example/existing-login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + payload := decodeJSONResponseData(t, recorder) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.NotEqual(t, "legacy-access-token", payload["access_token"]) + require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) + require.NotContains(t, payload, "adoption_required") + + accessToken, ok := payload["access_token"].(string) + require.True(t, ok) + claims, err := handler.authService.ValidateToken(accessToken) + 
require.NoError(t, err) + reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion) + + decisionCount, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, decisionCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) + + completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.NotContains(t, completion, "access_token") + require.NotContains(t, completion, "refresh_token") + require.NotContains(t, completion, "expires_in") + require.NotContains(t, completion, "token_type") + require.Equal(t, "/dashboard", completion["redirect"]) +} + +func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("blocked@example.com"). + SetUsername("blocked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("blocked-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("blocked-subject-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("blocked-backend-mode-browser-session-key"). 
+ SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("disabled-linked@example.com"). + SetUsername("disabled-linked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("disabled-linked-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("disabled-linked-subject"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("disabled-linked-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Disabled Linked User", + }). 
+ SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) { + payload := normalizePendingOAuthCompletionResponse(map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }) + + require.NotContains(t, payload, "access_token") + require.NotContains(t, payload, "refresh_token") + require.NotContains(t, payload, "expires_in") + require.NotContains(t, payload, "token_type") + require.Equal(t, "/dashboard", payload["redirect"]) +} + +func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("invitation-required-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invitation-123"). 
+ SetBrowserSessionKey("invitation-required-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Invite Example", + "suggested_avatar_url": "https://cdn.example/invite.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "error": "invitation_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, "invitation_required", data["error"]) + require.Equal(t, true, data["adoption_required"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invitation-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810") + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-123"). + SetBrowserSessionKey("create-account-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Fresh OIDC User", + "suggested_avatar_url": "https://cdn.example/fresh.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + createdUser, err := 
client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, service.StatusActive, createdUser.Status) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-create-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, createdUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-123"). + SetBrowserSessionKey("existing-email-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, oauthIntentLogin, payload["intent"]) + require.Equal(t, "oidc", payload["provider"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) + require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) + require.Nil(t, storedSession.ConsumedAt) + + identityCount, err := client.AuthIdentity.Query(). 
+ Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-existing-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) +} + +func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-normalized-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-normalized-123"). + SetBrowserSessionKey("existing-email-normalized-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, oauthIntentLogin, payload["intent"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + +func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-send-code-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-send-code-123"). 
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "email_required", + }, + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")}) + ginCtx.Request = req + + handler.SendPendingOAuthVerifyCode(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + +func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + 
CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-backend-mode-123"). + SetBrowserSessionKey("create-account-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) { + handler, _ := newOAuthPendingFlowTestHandler(t, false) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := 
httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge) +} + +func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") + ctx := context.Background() + + conflictOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(conflictOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetMetadata(map[string]any{ + "username": "owner-user", + }). + Save(ctx) + require.NoError(t, err) + + invitation, err := client.RedeemCode.Create(). + SetCode("INVITE123"). + SetType(service.RedeemTypeInvitation). + SetStatus(service.StatusUnused). + SetValue(0). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-conflict-session-token"). + SetIntent("login"). 
+ SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetBrowserSessionKey("create-account-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusConflict, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID) + require.NoError(t, err) + require.Equal(t, service.StatusUnused, storedInvitation.Status) + require.Nil(t, storedInvitation.UsedBy) + require.Nil(t, storedInvitation.UsedAt) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: 
map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + userRepoOptions: oauthPendingFlowUserRepoOptions{ + rejectDeleteWhileAuthIdentityExists: true, + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-finalize-failure-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-finalize-failure-123"). + SetBrowserSessionKey("create-account-finalize-failure-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error { + return errors.New("forced post-bind failure") + } + t.Cleanup(func() { + pendingOAuthCreateAccountPreCommitHook = nil + }) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + 
identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). 
+ SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-backend-mode-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-backend-mode-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"), + ). 
+ Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-invalid-password-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-invalid-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-invalid-password-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "INVALID_CREDENTIALS", payload["reason"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + oldOwnerHash, err := handler.authService.HashPassword("old-secret") + require.NoError(t, err) + oldOwner, err := client.User.Create(). + SetEmail("old-owner@example.com"). + SetUsername("old-owner"). + SetPasswordHash(oldOwnerHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(oldOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-soft-deleted-123"). 
+ SetMetadata(map[string]any{"username": "old-owner"}). + Save(ctx) + require.NoError(t, err) + + _, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx) + require.NoError(t, err) + + newOwnerHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + newOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(newOwnerHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-soft-deleted-owner-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-soft-deleted-123"). + SetTargetUserID(newOwner.ID). + SetResolvedEmail(newOwner.Email). + SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Recovered OIDC User", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err = client.AuthIdentity.Get(ctx, identity.ID) + require.NoError(t, err) + require.Equal(t, newOwner.ID, identity.UserID) +} + +func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) { + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3", + service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + firstSession, err := client.PendingAuthSession.Create(). + SetSessionToken("first-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-first-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("first-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + firstRecorder := httptest.NewRecorder() + firstGinCtx, _ := gin.CreateTestContext(firstRecorder) + firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody) + firstReq.Header.Set("Content-Type", "application/json") + firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)}) + firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")}) + firstGinCtx.Request = firstReq + + handler.BindOIDCOAuthLogin(firstGinCtx) + + require.Equal(t, http.StatusOK, firstRecorder.Code) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID) + require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID) + require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays) + 
require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + + secondSession, err := client.PendingAuthSession.Create(). + SetSessionToken("second-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-second-456"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("second-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Second OIDC User", + "suggested_avatar_url": "https://cdn.example/second.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + secondRecorder := httptest.NewRecorder() + secondGinCtx, _ := gin.CreateTestContext(secondRecorder) + secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)}) + secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")}) + secondGinCtx.Request = secondReq + + handler.BindOIDCOAuthLogin(secondGinCtx) + + require.Equal(t, http.StatusOK, secondRecorder.Code) + + storedUser, err = client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) +} + +func 
TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + _ = handler + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("resolve-target-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-target-123"). + SetResolvedEmail("owner@example.com"). + SetBrowserSessionKey("resolve-target-browser-session-key"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + require.NoError(t, err) + require.Equal(t, existingUser.ID, resolvedUserID) +} + +func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + }, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). 
+ Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-2fa-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, true, data["requires_2fa"]) + require.Equal(t, "o***r@example.com", data["user_email_masked"]) + tempToken, ok := data["temp_token"].(string) + require.True(t, ok) + require.NotEmpty(t, tempToken) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.NotNil(t, loginSession) + require.NotNil(t, loginSession.PendingOAuthBind) + require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken) + require.Equal(t, 
session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + service.SettingKeyAuthSourceDefaultOIDCBalance: "8", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2", + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(4). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-2fa-pending-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). 
+ SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-login-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(session.ID). + SetAdoptDisplayName(false). + SetAdoptAvatar(false). + Save(ctx) + require.NoError(t, err) + + tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession( + ctx, + existingUser.ID, + existingUser.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + require.NoError(t, err) + + code, err := totp.GenerateCode(secret, time.Now().UTC()) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)}) + ginCtx.Request = req + + handler.Login2FA(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + payload := decodeJSONResponseData(t, recorder) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + accessToken, ok := payload["access_token"].(string) + require.True(t, ok) + claims, err := handler.authService.ValidateToken(accessToken) + require.NoError(t, err) + reloadedUser, err := 
handler.userService.GetByID(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-login-2fa-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.Nil(t, loginSession) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 9.5, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + require.Empty(t, defaultSubAssigner.calls) +} + +func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) +} + +func newOAuthPendingFlowTestHandlerWithEmailVerification( + t *testing.T, + invitationEnabled bool, + email string, + code string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + cache := &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + email: { + Code: code, + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + } + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) +} + +func newOAuthPendingFlowTestHandlerWithOptions( + t *testing.T, + invitationEnabled bool, + emailVerifyEnabled bool, + emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + return 
newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + invitationEnabled: invitationEnabled, + emailVerifyEnabled: emailVerifyEnabled, + emailCache: emailCache, + }) +} + +type oauthPendingFlowTestHandlerOptions struct { + invitationEnabled bool + emailVerifyEnabled bool + emailCache service.EmailCache + settingValues map[string]string + defaultSubAssigner service.DefaultSubscriptionAssigner + totpCache service.TotpCache + totpEncryptor service.SecretEncryptor + userRepoOptions oauthPendingFlowUserRepoOptions +} + +func newOAuthPendingFlowTestHandlerWithDependencies( + t *testing.T, + options oauthPendingFlowTestHandlerOptions, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_avatars ( + user_id INTEGER PRIMARY KEY, + storage_provider TEXT NOT NULL, + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT '', + byte_size INTEGER NOT NULL DEFAULT 0, + sha256 TEXT NOT NULL DEFAULT '', + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + 
Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingValues := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + } + for key, value := range options.settingValues { + settingValues[key] = value + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) + userRepo := &oauthPendingFlowUserRepo{ + client: client, + options: options.userRepoOptions, + } + redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client} + var emailService *service.EmailService + if options.emailCache != nil { + emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + }, + }, options.emailCache) + } + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &oauthPendingFlowRefreshTokenCacheStub{}, + cfg, + settingSvc, + emailService, + nil, + nil, + nil, + options.defaultSubAssigner, + ) + userSvc := service.NewUserService(userRepo, nil, nil, nil) + var totpSvc *service.TotpService + if options.totpCache != nil || options.totpEncryptor != nil { + totpCache := options.totpCache + if totpCache == nil { + totpCache = &oauthPendingFlowTotpCacheStub{} + } + totpEncryptor := options.totpEncryptor + if totpEncryptor == nil { + totpEncryptor = oauthPendingFlowTotpEncryptorStub{} + } + totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil) + } + + return &AuthHandler{ + authService: authSvc, + userService: userSvc, + settingSvc: settingSvc, + totpService: totpSvc, + }, client +} + +func boolSettingValue(v bool) string { + if v { + return "true" + } + return "false" +} + +func boolPtr(v bool) *bool { + return &v +} + +type 
oauthPendingFlowSettingRepoStub struct { + values map[string]string +} + +func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type oauthPendingFlowRefreshTokenCacheStub struct{} + +type oauthPendingFlowEmailCacheStub struct { + verificationCodes map[string]*service.VerificationCodeData +} + +func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { + if s == nil || s.verificationCodes == nil { + return nil, nil + } + return s.verificationCodes[email], nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { + if s.verificationCodes == nil { + s.verificationCodes = map[string]*service.VerificationCodeData{} + } + s.verificationCodes[email] = data + return 
nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { + delete(s.verificationCodes, email) + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error 
{ + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +type oauthPendingFlowRedeemCodeRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error { + panic("unexpected Create call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrRedeemCodeNotFound + } + return nil, err + } + notes := "" + if entity.Notes != nil { + notes = *entity.Notes + } + return &service.RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: 
entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: notes, + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error { + if code == nil { + return nil + } + update := r.client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). + SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err +} + +func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error { + affected, err := r.client.RedeemCode.Update(). + Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). + SetStatus(service.StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). 
+ Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return service.ErrRedeemCodeUsed + } + return nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var envelope struct { + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) + return envelope.Data +} + +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + return payload +} + +type oauthPendingFlowAvatarRecord struct { + StorageProvider string + URL string +} + +func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`, + []any{userID}, + 
&rows, + ) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + if !rows.Next() { + require.NoError(t, rows.Err()) + return nil + } + + var record oauthPendingFlowAvatarRecord + require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL)) + require.NoError(t, rows.Err()) + return &record +} + +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, + []any{userID, providerType, grantReason}, + &rows, + ) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + + require.True(t, rows.Next()) + var count int + require.NoError(t, rows.Scan(&count)) + require.False(t, rows.Next()) + return count +} + +type oauthPendingFlowUserRepo struct { + client *dbent.Client + options oauthPendingFlowUserRepoOptions +} + +type oauthPendingFlowUserRepoOptions struct { + rejectDeleteWhileAuthIdentityExists bool +} + +func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { + entity, err := r.client.User.Create(). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). 
+ Save(ctx) + if err != nil { + return err + } + user.ID = entity.ID + user.CreatedAt = entity.CreatedAt + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + entity, err := r.client.User.Get(ctx, id) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { + entity, err := r.client.User.UpdateOneID(user.ID). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). 
+ Save(ctx) + if err != nil { + return err + } + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + if r.options.rejectDeleteWhileAuthIdentityExists { + count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx) + if err != nil { + return err + } + if count > 0 { + return errors.New("cannot delete user while auth identities still exist") + } + } + return r.client.User.DeleteOneID(id).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var rows entsql.Rows + if err := driver.Query( + ctx, + `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ); err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + if err := driver.Exec( + ctx, + `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, 
content_type, byte_size, sha256, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) +ON CONFLICT(user_id) DO UPDATE SET + storage_provider = excluded.storage_provider, + storage_key = excluded.storage_key, + url = excluded.url, + content_type = excluded.content_type, + byte_size = excluded.byte_size, + sha256 = excluded.sha256, + updated_at = CURRENT_TIMESTAMP`, + []any{ + userID, + input.StorageProvider, + input.StorageKey, + input.URL, + input.ContentType, + input.ByteSize, + input.SHA256, + }, + &result, + ); err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} + +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result) +} + +func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { + panic("unexpected DeductBalance call") +} + +func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) 
GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) + return count > 0, err +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := r.client.AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). 
+ All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: identity.ProviderType, + ProviderKey: identity.ProviderKey, + ProviderSubject: identity.ProviderSubject, + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: identity.Metadata, + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + return records, nil +} + +func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + update := r.client.User.UpdateOneID(userID) + if encryptedSecret == nil { + update = update.ClearTotpSecretEncrypted() + } else { + update = update.SetTotpSecretEncrypted(*encryptedSecret) + } + return update.Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(true). + SetTotpEnabledAt(time.Now().UTC()). + Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(false). + ClearTotpSecretEncrypted(). + ClearTotpEnabledAt(). 
+ Exec(ctx) +} + +func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { + if entity == nil { + return nil + } + return &service.User{ + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + TotpSecretEncrypted: entity.TotpSecretEncrypted, + TotpEnabled: entity.TotpEnabled, + TotpEnabledAt: entity.TotpEnabledAt, + TotalRecharged: entity.TotalRecharged, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} + +type oauthPendingFlowDefaultSubAssignerStub struct { + calls []service.AssignSubscriptionInput +} + +func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + return nil, false, nil +} + +type oauthPendingFlowTotpCacheStub struct { + setupSessions map[int64]*service.TotpSetupSession + loginSessions map[string]*service.TotpLoginSession + verifyAttempts map[int64]int +} + +func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) { + if s == nil || s.setupSessions == nil { + return nil, nil + } + return s.setupSessions[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error { + if s.setupSessions == nil { + s.setupSessions = map[int64]*service.TotpSetupSession{} + } + s.setupSessions[userID] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error { + delete(s.setupSessions, userID) + return nil +} + +func (s 
*oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) { + if s == nil || s.loginSessions == nil { + return nil, nil + } + return s.loginSessions[tempToken], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error { + if s.loginSessions == nil { + s.loginSessions = map[string]*service.TotpLoginSession{} + } + s.loginSessions[tempToken] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error { + delete(s.loginSessions, tempToken) + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s.verifyAttempts == nil { + s.verifyAttempts = map[int64]int{} + } + s.verifyAttempts[userID]++ + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s == nil || s.verifyAttempts == nil { + return 0, nil + } + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error { + delete(s.verifyAttempts, userID) + return nil +} + +type oauthPendingFlowTotpEncryptorStub struct{} + +func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) { + return plaintext, nil +} + +func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) { + return ciphertext, nil +} diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go new file mode 100644 index 0000000000000000000000000000000000000000..47bad942e3b423408feda1c5ee53161d8f4cf6fd --- /dev/null +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -0,0 +1,57 @@ +package handler + +import ( + "net/http" + "net/url" + 
"testing" + + "github.com/stretchr/testify/require" +) + +func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string { + t.Helper() + value, err := buildOAuthBindUserCookieValue(userID, secret) + require.NoError(t, err) + return value +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + decoded, err := decodeCookieValue(value) + require.NoError(t, err) + return decoded +} + +func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + require.NotEmpty(t, location) + + parsed, err := url.Parse(location) + require.NoError(t, err) + + rawValues := parsed.RawQuery + if rawValues == "" { + rawValues = parsed.Fragment + } + values, err := url.ParseQuery(rawValues) + require.NoError(t, err) + require.Equal(t, errorCode, values.Get("error")) + require.Equal(t, errorMessage, values.Get("error_message")) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 9d24df88ab1a08a2a436e120e7e367c65b66c881..0ac8871b963081ee575ebcc000c722d9f6b10db0 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -19,6 +19,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" @@ -32,14 +33,16 @@ import ( ) const ( - oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" - oidcOAuthStateCookieName = "oidc_oauth_state" - oidcOAuthVerifierCookie = "oidc_oauth_verifier" - 
oidcOAuthRedirectCookie = "oidc_oauth_redirect" - oidcOAuthNonceCookie = "oidc_oauth_nonce" - oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes - oidcOAuthDefaultRedirectTo = "/dashboard" - oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" + oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" + oidcOAuthStateCookieName = "oidc_oauth_state" + oidcOAuthVerifierCookie = "oidc_oauth_verifier" + oidcOAuthRedirectCookie = "oidc_oauth_redirect" + oidcOAuthNonceCookie = "oidc_oauth_nonce" + oidcOAuthIntentCookieName = "oidc_oauth_intent" + oidcOAuthBindUserCookieName = "oidc_oauth_bind_user" + oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + oidcOAuthDefaultRedirectTo = "/dashboard" + oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" ) type oidcTokenResponse struct { @@ -87,6 +90,8 @@ type oidcUserInfoClaims struct { Username string Subject string EmailVerified *bool + DisplayName string + AvatarURL string } type oidcJWKSet struct { @@ -127,9 +132,29 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) { redirectTo = oidcOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie) oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie) + intent := normalizeOAuthIntent(c.Query("intent")) + oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + 
return + } + oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie) + } else { + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) + } codeChallenge := "" if cfg.UsePKCE { @@ -199,6 +224,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) }() expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName) @@ -212,6 +239,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = oidcOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) codeVerifier := "" if cfg.UsePKCE { @@ -258,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } - if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" { - redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") - return - } + var idClaims *oidcIDTokenClaims + if cfg.ValidateIDToken { + if strings.TrimSpace(tokenResp.IDToken) == "" { + redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") + return + } - idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce) - if err != nil { - log.Printf("[OIDC OAuth] id_token validation failed: %v", err) - redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "") - return + idClaims, err = 
oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce) + if err != nil { + log.Printf("[OIDC OAuth] id_token validation failed: %v", err) + redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "") + return + } } userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp) @@ -277,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } - subject := strings.TrimSpace(idClaims.Subject) + subject := "" + if idClaims != nil { + subject = strings.TrimSpace(idClaims.Subject) + } if subject == "" { subject = strings.TrimSpace(userInfoClaims.Subject) } @@ -285,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "") return } - issuer := strings.TrimSpace(idClaims.Issuer) + issuer := "" + if idClaims != nil { + issuer = strings.TrimSpace(idClaims.Issuer) + } if issuer == "" { issuer = strings.TrimSpace(cfg.IssuerURL) } @@ -295,57 +338,252 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } emailVerified := userInfoClaims.EmailVerified - if emailVerified == nil { + if emailVerified == nil && idClaims != nil { emailVerified = idClaims.EmailVerified } - if cfg.RequireEmailVerified { - if emailVerified == nil || !*emailVerified { - redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") - return - } + if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) { + redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "") + return } identityKey := oidcIdentityKey(issuer, subject) - email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey) + compatEmail := strings.TrimSpace(userInfoClaims.Email) + if compatEmail == "" && idClaims != nil { + compatEmail = 
strings.TrimSpace(idClaims.Email) + } + email := oidcSyntheticEmailFromIdentityKey(identityKey) username := firstNonEmpty( userInfoClaims.Username, - idClaims.PreferredUsername, - idClaims.Name, + func() string { + if idClaims != nil { + return idClaims.PreferredUsername + } + return "" + }(), + func() string { + if idClaims != nil { + return idClaims.Name + } + return "" + }(), oidcFallbackUsername(subject), ) + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string { + if idClaims != nil { + return idClaims.Name + } + return "" + }(), username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "") + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, + Identity: identityRef, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } - // 传入空邀请码;如果需要邀请码,服务层返回 
ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) if err != nil { - if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") - return - } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityRef, + TargetUserID: &existingIdentityUser.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + if cfg.RequireEmailVerified { + if emailVerified == nil || !*emailVerified { + redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") + return + } + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOIDCOAuthChoicePendingSession( + c, + 
identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + true, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + redirectToFrontendCallback(c, frontendCallback) return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + if err := h.createOIDCOAuthChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + compatEmail, + compatEmailUser, + h.isForceEmailOnThirdPartySignup(c.Request.Context()), + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := findUserByNormalizedEmail(ctx, client, email) + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up 
compat email user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createOIDCOAuthChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + } + if forceEmailOnSignup && compatEmailUser == nil { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + TargetUserID: targetUserID, + ResolvedEmail: 
resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) } type completeOIDCOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating @@ -358,17 +596,87 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if updatedSession, handled, 
err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) 
c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -405,7 +713,7 @@ func oidcExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { + if strings.TrimSpace(codeVerifier) != "" { form.Set("code_verifier", codeVerifier) } @@ -560,9 +868,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC if verified, ok := getGJSONBool(body, "email_verified"); ok { claims.EmailVerified = &verified } + claims.DisplayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + ) + claims.AvatarURL = firstNonEmpty( + getGJSON(body, "picture"), + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) claims.Email = strings.TrimSpace(claims.Email) claims.Username = strings.TrimSpace(claims.Username) claims.Subject = strings.TrimSpace(claims.Subject) + claims.DisplayName = strings.TrimSpace(claims.DisplayName) + claims.AvatarURL = strings.TrimSpace(claims.AvatarURL) return claims } @@ -595,7 +920,7 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall if strings.TrimSpace(nonce) != "" { q.Set("nonce", nonce) } - if cfg.UsePKCE { + if strings.TrimSpace(codeChallenge) != "" { q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") } @@ -831,14 +1156,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string { return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain } -func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string { - email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail)) - if email != "" { - return email - } - return oidcSyntheticEmailFromIdentityKey(identityKey) -} - func 
oidcFallbackUsername(subject string) string { subject = strings.TrimSpace(subject) if subject == "" { diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a161aa77cf6faefaebae46f6a570227263484c90..3216d51e7fdf3e85ca81987144d3c6c67c9bc4ba 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -12,7 +13,15 @@ import ( "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -30,26 +39,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) { require.Contains(t, e1, "@oidc-connect.invalid") } -func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) { - identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a") - - email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey) - require.Equal(t, "user@example.com", email) - - email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey) - require.Equal(t, "idtoken@example.com", email) - - email = oidcSelectLoginEmail("", "", identityKey) - require.Contains(t, email, "@oidc-connect.invalid") - require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email) -} - func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) { cfg := config.OIDCConnectConfig{ AuthorizeURL: "https://issuer.example.com/auth", ClientID: "cid", Scopes: "openid email profile", - 
UsePKCE: true, } u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback") @@ -106,6 +100,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) { require.Error(t, err) } +func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) { + cfg := config.OIDCConnectConfig{} + + claims := oidcParseUserInfo(`{ + "sub":"subject-1", + "preferred_username":"alice", + "name":"Alice Example", + "picture":"https://cdn.example/avatar.png", + "email":"alice@example.com", + "email_verified":true + }`, cfg) + + require.Equal(t, "subject-1", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Example", claims.DisplayName) + require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL) + require.NotNil(t, claims.EmailVerified) + require.True(t, *claims.EmailVerified) +} + func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) @@ -118,3 +132,909 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { E: e, } } + +func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { + handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/oauth/authorize", + TokenURL: "https://issuer.example.com/oauth/token", + UserInfoURL: "https://issuer.example.com/oauth/userinfo", + JWKSURL: "https://issuer.example.com/oauth/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + }) + + recorder := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil) + c.Request = req + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84}) + + handler.OIDCOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "issuer.example.com/oauth/authorize") + require.Contains(t, location, "client_id=oidc-client") + require.Contains(t, location, "nonce=") + + cookies := recorder.Result().Cookies() + require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName)) + require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie)) + require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie)) + require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName)) + + intentCookie := findCookie(cookies, oidcOAuthIntentCookieName) + require.NotNil(t, intentCookie) + require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value)) + + bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName) + require.NotNil(t, bindCookie) + userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret") + require.NoError(t, err) + require.Equal(t, int64(84), userID) +} + +func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) { + handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/oauth/authorize", + TokenURL: "https://issuer.example.com/oauth/token", + UserInfoURL: "https://issuer.example.com/oauth/userinfo", + Scopes: "openid profile email", + RedirectURL: 
"https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + ValidateIDToken: false, + RequireEmailVerified: false, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil) + + handler.OIDCOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotContains(t, location, "code_challenge=") + require.NotContains(t, location, "nonce=") + require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie)) + require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie)) +} + +func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, r.ParseForm()) + require.Empty(t, r.PostForm.Get("code_verifier")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{ + Enabled: true, + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + 
FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: false, + ValidateIDToken: false, + RequireEmailVerified: false, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) +} + +func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-login", + PreferredUsername: "oidc_login", + DisplayName: "OIDC Login Display", + AvatarURL: "https://cdn.example/oidc-login.png", + Email: "oidc-login@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-subject-login"). 
+ SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) + require.Nil(t, completion["error"]) +} + +func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-disabled-subject", + PreferredUsername: "oidc_disabled", + DisplayName: "OIDC Disabled", + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))). + SetUsername("disabled-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-disabled-subject"). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-compat", + PreferredUsername: "oidc_compat", + DisplayName: "OIDC Compat Display", + AvatarURL: "https://cdn.example/oidc-compat.png", + Email: "legacy@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + +func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-unverified-compat", + PreferredUsername: "oidc_unverified", + DisplayName: "OIDC Unverified Compat Display", + AvatarURL: "https://cdn.example/oidc-unverified.png", + Email: "owner@example.com", + EmailVerified: false, + }) + defer cleanup() + cfg.RequireEmailVerified = true + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + _, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location")) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-invite", + PreferredUsername: "oidc_invite", + DisplayName: "OIDC Invite Display", + AvatarURL: "https://cdn.example/oidc-invite.png", + Email: "oidc-invite@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456")) + 
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-bind", + PreferredUsername: "oidc_bind", + DisplayName: "OIDC Bind Display", + AvatarURL: "https://cdn.example/oidc-bind.png", + Email: "oidc-bind@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). 
+ SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser)) + req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentBindCurrentUser, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, cfg.IssuerURL, session.ProviderKey) + require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) + + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Empty(t, completion["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, userCount) +} + +func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-1"). + SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Display", + "suggested_avatar_url": "https://cdn.example/oidc.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "OIDC Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "OIDC Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("oidc-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-choice-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-choice-subject-1"). + SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-no-adoption"). + SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser-no-adoption"). 
+ SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Legacy", + "suggested_avatar_url": "https://cdn.example/oidc-legacy.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "oidc_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + +func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(existingOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-conflict-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-conflict-subject"). + SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-conflict-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusConflict, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"]) + + userCount, err := client.User.Query(). + Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +type oidcProviderFixture struct { + Subject string + PreferredUsername string + DisplayName string + AvatarURL string + Email string + EmailVerified bool +} + +func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler { + t.Helper() + handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) + return handler +} + +func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) { + t.Helper() + handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled) + handler.settingSvc = nil + handler.cfg = &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + OIDC: oauthCfg, + } + return handler, 
client +} + +func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "test-kid" + jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}} + tokenResponse := oidcTokenResponse{ + AccessToken: "oidc-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + userInfoPayload := map[string]any{ + "sub": fixture.Subject, + "preferred_username": fixture.PreferredUsername, + "name": fixture.DisplayName, + "picture": fixture.AvatarURL, + "email": fixture.Email, + "email_verified": fixture.EmailVerified, + } + + var issuer string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + require.NoError(t, json.NewEncoder(w).Encode(tokenResponse)) + case "/userinfo": + require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload)) + case "/jwks": + require.NoError(t, json.NewEncoder(w).Encode(jwks)) + default: + http.NotFound(w, r) + } + })) + + issuer = server.URL + now := time.Now() + claims := oidcIDTokenClaims{ + Email: fixture.Email, + EmailVerified: boolPtr(fixture.EmailVerified), + PreferredUsername: fixture.PreferredUsername, + Name: fixture.DisplayName, + Nonce: "nonce-" + fixture.Subject, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: fixture.Subject, + Audience: jwt.ClaimStrings{"oidc-client"}, + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)), + ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = kid + tokenResponse.IDToken, err = token.SignedString(privateKey) + require.NoError(t, err) + + cfg := config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "Test OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: 
issuer, + AuthorizeURL: issuer + "/authorize", + TokenURL: issuer + "/token", + UserInfoURL: issuer + "/userinfo", + JWKSURL: issuer + "/jwks", + Scopes: "openid profile email", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + RequireEmailVerified: false, + } + return cfg, server.Close +} diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1924cb812cd35ee9f438965e46c58a2e43c8f6d0 --- /dev/null +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -0,0 +1,61 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 29, + Email: "session@example.com", + Username: "session-user", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 7, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := &AuthHandler{authService: authService} + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, 
"/api/v1/auth/revoke-all-sessions", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29}) + + handler.RevokeAllSessions(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(8), repo.user.TokenVersion) + + var resp struct { + Code int `json:"code"` + Data struct { + Message string `json:"message"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message) +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..efee4cc01a4b04255e94900c350050b6004e1ae6 --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -0,0 +1,1349 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat" + wechatOAuthCookieMaxAgeSec = 10 * 60 + wechatOAuthStateCookieName = "wechat_oauth_state" + wechatOAuthRedirectCookieName = "wechat_oauth_redirect" + wechatOAuthIntentCookieName = "wechat_oauth_intent" + wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthBindUserCookieName = "wechat_oauth_bind_user" + wechatOAuthDefaultRedirectTo = "/dashboard" + wechatOAuthDefaultFrontendCB = 
"/auth/wechat/callback" + wechatOAuthProviderKey = "wechat-main" + wechatOAuthLegacyProviderKey = "wechat" + wechatPaymentOAuthCookiePath = "/api/v1/auth/oauth/wechat/payment" + wechatPaymentOAuthStateName = "wechat_payment_oauth_state" + wechatPaymentOAuthRedirect = "wechat_payment_oauth_redirect" + wechatPaymentOAuthContextName = "wechat_payment_oauth_context" + wechatPaymentOAuthScope = "wechat_payment_oauth_scope" + wechatPaymentOAuthDefaultTo = "/purchase" + wechatPaymentOAuthFrontendCB = "/auth/wechat/payment/callback" + + wechatOAuthIntentLogin = "login" + wechatOAuthIntentBind = "bind_current_user" + wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email" +) + +var ( + wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token" + wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo" +) + +type wechatOAuthConfig struct { + mode string + appID string + appSecret string + authorizeURL string + scope string + redirectURI string + frontendCallback string + openEnabled bool + mpEnabled bool +} + +type wechatOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatOAuthUserInfoResponse struct { + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + HeadImgURL string `json:"headimgurl"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatPaymentOAuthContext struct { + PaymentType string `json:"payment_type"` + Amount string `json:"amount,omitempty"` + OrderType string `json:"order_type,omitempty"` + PlanID int64 `json:"plan_id,omitempty"` +} + +// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived +// browser cookies required by the rebuild pending-auth 
bridge. +func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + intent := normalizeWeChatOAuthIntent(c.Query("intent")) + secureCookie := isRequestHTTPS(c) + wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie) + } else { + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + } + + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, 
infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, +// and stores the result in the unified pending-auth flow. +func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { + frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context()) + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + + intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName) + 
mode, err := readCookieDecoded(c, wechatOAuthModeCookieName) + if err != nil || strings.TrimSpace(mode) == "" { + redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "") + return + } + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error())) + return + } + + unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) + openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) + providerSubject := unionid + if providerSubject == "" { + if cfg.requiresUnionID() { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") + return + } + providerSubject = openid + } + if providerSubject == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") + return + } + + username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) + email := wechatSyntheticEmail(providerSubject) + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": providerSubject, + "openid": openid, + "unionid": unionid, + "mode": cfg.mode, + "channel": cfg.mode, + "channel_app_id": strings.TrimSpace(cfg.appID), + "channel_subject": openid, + "suggested_display_name": strings.TrimSpace(userInfo.Nickname), + "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), + } + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + } + + normalizedIntent := normalizeWeChatOAuthIntent(intent) + if normalizedIntent == wechatOAuthIntentBind { + if err := 
h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil { + switch infraerrors.Code(err) { + case http.StatusConflict: + redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err)) + case http.StatusUnauthorized, http.StatusForbidden: + redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err)) + default: + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + } + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser == nil { + existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + } + if existingIdentityUser != nil { + if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createWeChatChoicePendingSession( + c, + identityRef, + 
email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + "", + nil, + true, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createWeChatChoicePendingSession( + c, + identityRef, + email, + email, + redirectTo, + browserSessionKey, + upstreamClaims, + "", + nil, + false, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +// WeChatPaymentOAuthStart starts the WeChat payment OAuth flow. +// GET /api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=/purchase +func (h *AuthHandler) WeChatPaymentOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + paymentType := normalizeWeChatPaymentType(c.Query("payment_type")) + if paymentType == "" { + response.BadRequest(c, "Invalid payment type") + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(c.Query("redirect"))) + if redirectTo == "" { + redirectTo = wechatPaymentOAuthDefaultTo + } + rawContext, err := encodeWeChatPaymentOAuthContext(wechatPaymentOAuthContext{ + PaymentType: paymentType, + Amount: strings.TrimSpace(c.Query("amount")), + OrderType: strings.TrimSpace(c.Query("order_type")), + PlanID: parseWeChatPaymentPlanID(c.Query("plan_id")), + }) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONTEXT_ENCODE_FAILED", "failed to encode oauth context").WithCause(err)) + return + } + + scope := normalizeWeChatPaymentScope(c.Query("scope")) + 
secureCookie := isRequestHTTPS(c) + wechatPaymentSetCookie(c, wechatPaymentOAuthStateName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthRedirect, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthContextName, encodeCookieValue(rawContext), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatPaymentSetCookie(c, wechatPaymentOAuthScope, encodeCookieValue(scope), wechatOAuthCookieMaxAgeSec, secureCookie) + + cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c) + cfg.scope = scope + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatPaymentOAuthCallback exchanges a payment OAuth code for an OpenID and +// forwards the browser back to the frontend callback route. 
+func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) { + frontendCallback := wechatPaymentOAuthFrontendCB + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatPaymentOAuthStateName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatPaymentOAuthRedirect) + redirectTo = normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(redirectTo)) + if redirectTo == "" { + redirectTo = wechatPaymentOAuthDefaultTo + } + + rawContext, _ := readCookieDecoded(c, wechatPaymentOAuthContextName) + paymentContext, err := decodeWeChatPaymentOAuthContext(rawContext) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_context", "invalid oauth context", "") + return + } + if paymentContext.PaymentType == "" { + paymentContext.PaymentType = payment.TypeWxpay + } + + scope, _ := readCookieDecoded(c, wechatPaymentOAuthScope) + scope = normalizeWeChatPaymentScope(scope) + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), 
infraerrors.Message(err)) + return + } + cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c) + tokenResp, err := exchangeWeChatOAuthCode(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", err.Error()) + return + } + + openid := strings.TrimSpace(tokenResp.OpenID) + if openid == "" { + redirectOAuthError(c, frontendCallback, "missing_openid", "missing openid", "") + return + } + if strings.TrimSpace(tokenResp.Scope) != "" { + scope = strings.TrimSpace(tokenResp.Scope) + } + + resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{ + OpenID: openid, + PaymentType: paymentContext.PaymentType, + Amount: paymentContext.Amount, + OrderType: paymentContext.OrderType, + PlanID: paymentContext.PlanID, + RedirectTo: redirectTo, + Scope: scope, + }) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "") + return + } + + fragment := url.Values{} + fragment.Set("wechat_resume_token", resumeToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService { + var legacyKey []byte + key, err := payment.ProvideEncryptionKey(h.cfg) + if err == nil { + legacyKey = []byte(key) + } + return service.NewLegacyAwarePaymentResumeService(legacyKey) +} + +type completeWeChatOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by +// validating the invitation code and consuming the current pending browser session. 
+// POST /api/v1/auth/oauth/wechat/complete-registration +func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { + var req completeWeChatOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, 
infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) createWeChatPendingSession( + c *gin.Context, + intent string, + providerSubject string, + email string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + tokenPair *service.TokenPair, + authErr error, + targetUserID *int64, +) error { + completionResponse := map[string]any{ + "redirect": redirectTo, + } + if authErr != nil { + if errors.Is(authErr, service.ErrOAuthInvitationRequired) { + completionResponse["error"] = "invitation_required" + } else { + return 
authErr + } + } else if tokenPair != nil { + completionResponse["access_token"] = tokenPair.AccessToken + completionResponse["refresh_token"] = tokenPair.RefreshToken + completionResponse["expires_in"] = tokenPair.ExpiresIn + completionResponse["token_type"] = "Bearer" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: intent, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + }, + TargetUserID: targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) createWeChatChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + 
completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + } + if forceEmailOnSignup { + completionResponse["choice_reason"] = "force_email_on_signup" + } + + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) createWeChatBindPendingSession( + c *gin.Context, + cfg wechatOAuthConfig, + providerSubject string, + channelSubject string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName) + if err != nil { + return err + } + if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil { + return err + } + return h.createWeChatPendingSession( + c, + wechatOAuthIntentBind, + providerSubject, + currentUser.Email, + redirectTo, + browserSessionKey, + upstreamClaims, + nil, + nil, + ¤tUser.ID, + ) +} + +func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName) + if err != nil { + return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + userEntity, err := client.User.Get(c.Request.Context(), userID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, 
infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account") + } + return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) ensureWeChatBindOwnership( + ctx context.Context, + userID int64, + providerSubject string, + cfg wechatOAuthConfig, + channelSubject string, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), + authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), + ). + All(ctx) + if err != nil { + return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) + } + for _, identity := range identities { + if identity != nil && identity.UserID != userID { + activeOwner, lookupErr := findActiveUserByID(ctx, client, identity.UserID) + if lookupErr != nil { + return lookupErr + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + } + + channelSubject = strings.TrimSpace(channelSubject) + channelAppID := strings.TrimSpace(cfg.appID) + if channelSubject == "" || channelAppID == "" { + return nil + } + + channels, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), + authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). 
+ All(ctx) + if err != nil { + return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) + } + for _, channel := range channels { + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + activeOwner, lookupErr := findActiveUserByID(ctx, client, channel.Edges.Identity.UserID) + if lookupErr != nil { + return lookupErr + } + if activeOwner != nil { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + } + return nil +} + +func (h *AuthHandler) findWeChatUserByLegacyOpenID( + ctx context.Context, + identity service.PendingAuthIdentityKey, + cfg wechatOAuthConfig, + openid string, +) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + providerType := strings.TrimSpace(identity.ProviderType) + providerSubject := strings.TrimSpace(identity.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey) + if providerSubject != "" { + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + WithUser(). 
+ All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) + } + } + + openid = strings.TrimSpace(openid) + channel := strings.TrimSpace(cfg.mode) + channelAppID := strings.TrimSpace(cfg.appID) + if openid != "" && channel != "" && channelAppID != "" { + records, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(openid), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if user, err := singleWeChatChannelUser(records); err != nil || user != nil { + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) + } + } + + if openid == "" { + return nil, nil + } + + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(openid), + ). + WithUser(). 
+ All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + user, err := singleWeChatIdentityUser(records) + if err != nil || user == nil { + return user, err + } + return findActiveUserByID(ctx, client, user.ID) +} + +func wechatCompatibleProviderKeys(providerKey string) []string { + preferred := strings.TrimSpace(providerKey) + if preferred == "" { + preferred = wechatOAuthProviderKey + } + keys := []string{preferred} + if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) { + keys = append(keys, wechatOAuthLegacyProviderKey) + } + return keys +} + +func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.User + continue + } + if resolved.ID != record.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + return resolved, nil +} + +func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.Identity.Edges.User + continue + } + if resolved.ID != record.Edges.Identity.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + return resolved, nil +} + +func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( + ctx context.Context, + userID int64, + identity service.PendingAuthIdentityKey, + upstreamClaims map[string]any, +) error { + client := h.entClient() + if client == nil { + return 
infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + tx, err := client.Tx(ctx) + if err != nil { + return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims), + }, userID) + if err != nil { + return err + } + return tx.Commit() +} + +func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { + mode, err := resolveWeChatOAuthMode(rawMode, c) + if err != nil { + return wechatOAuthConfig{}, err + } + + if h == nil || h.settingSvc == nil { + return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready") + } + + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + settings, err := h.settingSvc.GetAllSettings(ctx) + if err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + + effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err != nil { + return wechatOAuthConfig{}, err + } + if !effective.SupportsMode(mode) { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + + cfg := wechatOAuthConfig{ + mode: mode, + appID: strings.TrimSpace(effective.AppIDForMode(mode)), + appSecret: strings.TrimSpace(effective.AppSecretForMode(mode)), + redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")), + frontendCallback: 
firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB), + scope: effective.ScopeForMode(mode), + openEnabled: effective.OpenEnabled, + mpEnabled: effective.MPEnabled, + } + + switch mode { + case "mp": + cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" + default: + cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" + } + if strings.TrimSpace(cfg.redirectURI) == "" { + return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + + return cfg, nil +} + +func (cfg wechatOAuthConfig) requiresUnionID() bool { + return cfg.openEnabled && cfg.mpEnabled +} + +func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string { + if h != nil && h.settingSvc != nil { + cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" { + return strings.TrimSpace(cfg.FrontendRedirectURL) + } + } + return wechatOAuthDefaultFrontendCB +} + +func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { + mode := strings.ToLower(strings.TrimSpace(rawMode)) + if mode == "" { + if isWeChatBrowserRequest(c) { + return "mp", nil + } + return "open", nil + } + if mode != "open" && mode != "mp" { + return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp") + } + return mode, nil +} + +func isWeChatBrowserRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger") +} + +func normalizeWeChatOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "login": + return wechatOAuthIntentLogin + case "bind", "bind_current_user": + return wechatOAuthIntentBind + case "adopt", "adopt_existing_user_by_email": + return wechatOAuthIntentAdoptEmail + default: + return 
wechatOAuthIntentLogin + } +} + +func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) { + u, err := url.Parse(cfg.authorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize url: %w", err) + } + query := u.Query() + query.Set("appid", cfg.appID) + query.Set("redirect_uri", cfg.redirectURI) + query.Set("response_type", "code") + query.Set("scope", cfg.scope) + query.Set("state", state) + u.RawQuery = query.Encode() + u.Fragment = "wechat_redirect" + return u.String(), nil +} + +func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string { + callbackPath = strings.TrimSpace(callbackPath) + if callbackPath == "" { + return "" + } + + if raw := strings.TrimSpace(apiBaseURL); raw != "" { + if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" { + basePath := strings.TrimRight(parsed.EscapedPath(), "/") + targetPath := callbackPath + if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") { + targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1") + } else if basePath != "" { + targetPath = basePath + callbackPath + } + return parsed.Scheme + "://" + parsed.Host + targetPath + } + } + + if c == nil || c.Request == nil { + return "" + } + scheme := "http" + if isRequestHTTPS(c) { + scheme = "https" + } + host := strings.TrimSpace(c.Request.Host) + if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + if host == "" { + return "" + } + return scheme + "://" + host + callbackPath +} + +func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) { + tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code) + if err != nil { + return nil, nil, err + } + userInfo, err := fetchWeChatUserInfo(ctx, tokenResp) + if err != nil { + return nil, nil, err 
+ } + return tokenResp, userInfo, nil +} + +func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) { + endpoint, err := url.Parse(wechatOAuthAccessTokenURL) + if err != nil { + return nil, fmt.Errorf("parse wechat access token url: %w", err) + } + + query := endpoint.Query() + query.Set("appid", cfg.appID) + query.Set("secret", cfg.appSecret) + query.Set("code", strings.TrimSpace(code)) + query.Set("grant_type", "authorization_code") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat access token request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat access token: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat access token response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode) + } + + var tokenResp wechatOAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode wechat access token response: %w", err) + } + if tokenResp.ErrCode != 0 { + return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg)) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, fmt.Errorf("wechat access token missing access_token") + } + return &tokenResp, nil +} + +func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) { + if tokenResp == nil { + return nil, fmt.Errorf("wechat token response is nil") + } + + endpoint, err := url.Parse(wechatOAuthUserInfoURL) + if err != 
nil { + return nil, fmt.Errorf("parse wechat userinfo url: %w", err) + } + query := endpoint.Query() + query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken)) + query.Set("openid", strings.TrimSpace(tokenResp.OpenID)) + query.Set("lang", "zh_CN") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat userinfo request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat userinfo: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat userinfo response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode) + } + + var userInfo wechatOAuthUserInfoResponse + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, fmt.Errorf("decode wechat userinfo response: %w", err) + } + if userInfo.ErrCode != 0 { + return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg)) + } + return &userInfo, nil +} + +func wechatSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain +} + +func wechatFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "wechat_user" + } + return "wechat_" + truncateFragmentValue(subject) +} + +func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: 
http.SameSiteLaxMode, + }) +} + +func wechatClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func normalizeWeChatPaymentType(raw string) string { + switch strings.TrimSpace(raw) { + case payment.TypeWxpay, payment.TypeWxpayDirect: + return strings.TrimSpace(raw) + default: + return "" + } +} + +func normalizeWeChatPaymentScope(raw string) string { + for _, part := range strings.FieldsFunc(strings.TrimSpace(raw), func(r rune) bool { + return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r' + }) { + switch strings.TrimSpace(part) { + case "snsapi_userinfo": + return "snsapi_userinfo" + case "snsapi_base": + return "snsapi_base" + } + } + return "snsapi_base" +} + +func normalizeWeChatPaymentRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return wechatPaymentOAuthDefaultTo + } + if path == "/payment" { + return "/purchase" + } + if strings.HasPrefix(path, "/payment?") { + return "/purchase" + strings.TrimPrefix(path, "/payment") + } + return path +} + +func (h *AuthHandler) resolveWeChatPaymentOAuthCallbackURL(ctx context.Context, c *gin.Context) string { + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + if settings, err := h.settingSvc.GetAllSettings(ctx); err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + return resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/payment/callback") +} + +func encodeWeChatPaymentOAuthContext(ctx wechatPaymentOAuthContext) (string, error) { + data, err := json.Marshal(ctx) + if err != nil { + return "", err + } + return string(data), nil +} + +func decodeWeChatPaymentOAuthContext(raw string) (wechatPaymentOAuthContext, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return wechatPaymentOAuthContext{}, 
nil + } + var ctx wechatPaymentOAuthContext + if err := json.Unmarshal([]byte(raw), &ctx); err != nil { + return wechatPaymentOAuthContext{}, err + } + return ctx, nil +} + +func parseWeChatPaymentPlanID(raw string) int64 { + id, _ := strconv.ParseInt(strings.TrimSpace(raw), 10, 64) + return id +} + +func wechatPaymentSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatPaymentOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatPaymentClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatPaymentOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7cf114c1c10cfe20aafe7016c50a55e89ab0e4c1 --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -0,0 +1,1497 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + 
"entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-open-app", + service.SettingKeyWeChatConnectAppSecret: "wx-open-secret", + service.SettingKeyWeChatConnectMode: "open", + service.SettingKeyWeChatConnectScopes: "snsapi_login", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "appid=wx-open-app") + require.Contains(t, location, "scope=snsapi_login") + + cookies := recorder.Result().Cookies() + require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) + require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) +} + +func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-shared-app", + 
service.SettingKeyWeChatConnectAppSecret: "wx-shared-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectOpenEnabled: "true", + service.SettingKeyWeChatConnectMPEnabled: "true", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "connect/qrconnect") + require.Contains(t, location, "scope=snsapi_login") +} + +func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, 
r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat", session.ProviderType) + require.Equal(t, "wechat-main", session.ProviderKey) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) + require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) +} + +func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := 
gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Equal(t, "openid-123", session.ProviderSubject) + require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) +} + +func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + 
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(wechatSyntheticEmail("union-456")). + SetUsername("wechat-existing-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"username": "wechat-existing-user"}). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthIntentLogin, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) + _, hasRefreshToken := completion["refresh_token"] + require.False(t, hasRefreshToken) +} + +func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail(wechatSyntheticEmail("union-disabled")). + SetUsername("disabled-user"). + SetPasswordHash("hash"). 
+ SetRole(service.RoleUser). + SetStatus(service.StatusDisabled). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-disabled"). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`)) + return + } + http.NotFound(w, r) + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + + handler, client := 
newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) + defer client.Close() + handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + handler.cfg.Totp.EncryptionKeyConfigured = true + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123")) + req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat")) + req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`)) + req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base")) + c.Request = req + + handler.WeChatPaymentOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + parsed, err := url.Parse(location) + require.NoError(t, err) + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect")) + require.NotEmpty(t, fragment.Get("wechat_resume_token")) + require.Empty(t, fragment.Get("openid")) + require.Empty(t, fragment.Get("payment_type")) + require.Empty(t, fragment.Get("amount")) + require.Empty(t, fragment.Get("order_type")) + require.Empty(t, fragment.Get("plan_id")) + + claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token")) + require.NoError(t, err) + require.Equal(t, "openid-123", claims.OpenID) + require.Equal(t, payment.TypeWxpay, claims.PaymentType) + require.Equal(t, "12.5", claims.Amount) + require.Equal(t, payment.OrderTypeSubscription, claims.OrderType) + require.EqualValues(t, 7, claims.PlanID) + 
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) +} + +func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`)) + return + } + http.NotFound(w, r) + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) + defer client.Close() + + legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + explicitSigningKey := "explicit-payment-resume-signing-key" + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey) + handler.cfg.Totp.EncryptionKey = legacyKeyHex + handler.cfg.Totp.EncryptionKeyConfigured = true + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed")) + req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat")) + req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`)) + req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base")) + c.Request = req + + handler.WeChatPaymentOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + 
location := recorder.Header().Get("Location") + parsed, err := url.Parse(location) + require.NoError(t, err) + fragment, err := url.ParseQuery(parsed.Fragment) + require.NoError(t, err) + + token := fragment.Get("wechat_resume_token") + require.NotEmpty(t, token) + + claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token) + require.NoError(t, err) + require.Equal(t, "openid-mixed-key", claims.OpenID) + require.Equal(t, payment.TypeWxpay, claims.PaymentType) + require.Equal(t, "18.8", claims.Amount) + require.Equal(t, payment.OrderTypeSubscription, claims.OrderType) + require.EqualValues(t, 9, claims.PlanID) + require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) + + _, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token) + require.Error(t, err) +} + +func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { + testCases := []struct { + name string + mode string + appID string + appSecret string + openID string + }{ + { + name: "open", + mode: "open", + appID: "wx-open-app", + appSecret: "wx-open-secret", + openID: "openid-open-123", + }, + { + name: "mp", + mode: "mp", + appID: "wx-mp-app", + appSecret: "wx-mp-secret", + openID: "openid-mp-123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`)) + 
case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback")) + defer client.Close() + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(context.Background()) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode)) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). 
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) + require.NoError(t, err) + require.Equal(t, wechatOAuthIntentBind, session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, currentUser.ID, *session.TargetUserID) + require.Equal(t, currentUser.Email, session.ResolvedEmail) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"]) + require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"]) + require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"]) + require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"]) + + completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completionResponse["redirect"]) + _, hasAccessToken := completionResponse["access_token"] + require.False(t, hasAccessToken) + }) + } +} + +func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + ownerIdentity, err := client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("union-owner"). + SetMetadata(map[string]any{"unionid": "union-owner"}). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentityChannel.Create(). + SetIdentityID(ownerIdentity.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetChannel("open"). + SetChannelAppID("wx-open-app"). + SetChannelSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", 
"application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", 
"application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, true) + defer client.Close() + + ctx := context.Background() + redeemRepo := repository.NewRedeemCodeRepository(client) + require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ + Code: "invite-1", + Type: service.RedeemTypeInvitation, + Status: service.StatusUnused, + })) + + callbackRecorder := httptest.NewRecorder() + callbackCtx, _ := gin.CreateTestContext(callbackRecorder) + callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + callbackReq.Host = "api.example.com" + callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + callbackCtx.Request = callbackReq + + handler.WeChatOAuthCallback(callbackCtx) + + require.Equal(t, http.StatusFound, callbackRecorder.Code) + require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) + + sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) + + pendingSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"]) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) + completeRecorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(completeRecorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, completeRecorder.Code) + responseData := decodeJSONBody(t, completeRecorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["adoption_required"]) + require.Empty(t, responseData["access_token"]) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, consumed.ConsumedAt) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). 
+ Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, channelCount) + + decisionCount, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). + Count(ctx) + require.NoError(t, err) + require.Zero(t, decisionCount) +} + +func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-no-adoption-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("wechat-subject-no-adoption"). + SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-browser-no-adoption"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + "suggested_display_name": "WeChat Legacy", + "suggested_avatar_url": "https://cdn.example/wechat-legacy.png", + "mode": "open", + "channel": "open", + "channel_app_id": "wx-open-app", + "channel_subject": "openid-legacy", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + require.NotEmpty(t, responseData["refresh_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat_user", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.False(t, decision.AdoptDisplayName) + require.False(t, decision.AdoptAvatar) +} + +func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). 
+ Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + openIDIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, openIDIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). 
+ Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + +func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-invalid-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("wechat-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-choice-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("wechat-choice-subject-1"). + SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + 
w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). 
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + legacyIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, legacyIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + +func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil) +} + +func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string { + return map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: appID, + service.SettingKeyWeChatConnectAppSecret: secret, + service.SettingKeyWeChatConnectMode: mode, + service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode), + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect, + } +} + +func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + userRepo := &oauthPendingFlowUserRepo{client: client} + redeemRepo := repository.NewRedeemCodeRepository(client) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + values := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + } + for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", 
"/auth/wechat/callback") { + values[key] = value + } + for key, value := range extraSettings { + values[key] = value + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg) + + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &wechatOAuthRefreshTokenCacheStub{}, + cfg, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + cfg: cfg, + }, client +} + +type wechatOAuthSettingRepoStub struct { + values map[string]string +} + +func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type wechatOAuthRefreshTokenCacheStub struct{} + +func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) 
GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go new file mode 100644 index 0000000000000000000000000000000000000000..8982b80defc9b47e124e16f0721866949bd0d1d6 --- /dev/null +++ b/backend/internal/handler/available_channel_handler.go @@ -0,0 +1,283 @@ +package handler + +import ( + "sort" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AvailableChannelHandler 处理用户侧「可用渠道」查询。 +// +// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做三层过滤: +// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道; +// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些; +// 3. 
平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型, +// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问 +// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏; +// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels +// / 内部 ID / Status 等管理字段)。 +type AvailableChannelHandler struct { + channelService *service.ChannelService + apiKeyService *service.APIKeyService + settingService *service.SettingService +} + +// NewAvailableChannelHandler 创建用户侧可用渠道 handler。 +func NewAvailableChannelHandler( + channelService *service.ChannelService, + apiKeyService *service.APIKeyService, + settingService *service.SettingService, +) *AvailableChannelHandler { + return &AvailableChannelHandler{ + channelService: channelService, + apiKeyService: apiKeyService, + settingService: settingService, + } +} + +// featureEnabled 返回 available-channels 开关是否启用。默认关闭(opt-in)。 +func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool { + if h.settingService == nil { + return false + } + return h.settingService.GetAvailableChannelsRuntime(c.Request.Context()).Enabled +} + +// userAvailableGroup 用户可见的分组概要(白名单字段)。 +// +// 前端据此区分专属 vs 公开分组(IsExclusive)、订阅 vs 标准分组(SubscriptionType, +// 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走 +// /groups/rates,和 API 密钥页面保持一致。 +type userAvailableGroup struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` +} + +// userSupportedModelPricing 用户可见的定价字段白名单。 +type userSupportedModelPricing struct { + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []userPricingIntervalDTO 
`json:"intervals"` +} + +// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。 +type userPricingIntervalDTO struct { + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` +} + +// userSupportedModel 用户可见的支持模型条目。 +type userSupportedModel struct { + Name string `json:"name"` + Platform string `json:"platform"` + Pricing *userSupportedModelPricing `json:"pricing"` +} + +// userChannelPlatformSection 单渠道内某个平台的子视图:用户可见的分组 + 该平台 +// 支持的模型。按 platform 聚合后让前端可以把渠道名作为 row-group 一次渲染, +// 后面的平台行按 sections 顺序铺开。 +type userChannelPlatformSection struct { + Platform string `json:"platform"` + Groups []userAvailableGroup `json:"groups"` + SupportedModels []userSupportedModel `json:"supported_models"` +} + +// userAvailableChannel 用户可见的渠道条目(白名单字段)。 +// +// 每个渠道聚合为一条记录,内嵌 platforms 子数组:每个 section 对应一个平台, +// 包含该平台的 groups 和 supported_models。 +type userAvailableChannel struct { + Name string `json:"name"` + Description string `json:"description"` + Platforms []userChannelPlatformSection `json:"platforms"` +} + +// List 列出当前用户可见的「可用渠道」。 +// GET /api/v1/channels/available +func (h *AvailableChannelHandler) List(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + // Feature 未启用时返回空数组(不暴露渠道信息)。检查放在认证之后, + // 保持与未开关前的 401 行为一致:未登录先 401,登录后再按开关决定。 + if !h.featureEnabled(c) { + response.Success(c, []userAvailableChannel{}) + return + } + + userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + allowedGroupIDs := make(map[int64]struct{}, len(userGroups)) + for i := range 
userGroups { + allowedGroupIDs[userGroups[i].ID] = struct{}{} + } + + channels, err := h.channelService.ListAvailable(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]userAvailableChannel, 0, len(channels)) + for _, ch := range channels { + if ch.Status != service.StatusActive { + continue + } + visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs) + if len(visibleGroups) == 0 { + continue + } + sections := buildPlatformSections(ch, visibleGroups) + if len(sections) == 0 { + continue + } + out = append(out, userAvailableChannel{ + Name: ch.Name, + Description: ch.Description, + Platforms: sections, + }) + } + + response.Success(c, out) +} + +// buildPlatformSections 把一个渠道按 visibleGroups 的平台集合拆成有序的 section 列表: +// 每个 section 对应一个平台,只包含该平台的 groups 和 supported_models。 +// 输出按 platform 字母序稳定排序,便于前端等效比较与回归测试。 +func buildPlatformSections( + ch service.AvailableChannel, + visibleGroups []userAvailableGroup, +) []userChannelPlatformSection { + groupsByPlatform := make(map[string][]userAvailableGroup, 4) + for _, g := range visibleGroups { + if g.Platform == "" { + continue + } + groupsByPlatform[g.Platform] = append(groupsByPlatform[g.Platform], g) + } + if len(groupsByPlatform) == 0 { + return nil + } + + platforms := make([]string, 0, len(groupsByPlatform)) + for p := range groupsByPlatform { + platforms = append(platforms, p) + } + sort.Strings(platforms) + + sections := make([]userChannelPlatformSection, 0, len(platforms)) + for _, platform := range platforms { + platformSet := map[string]struct{}{platform: {}} + sections = append(sections, userChannelPlatformSection{ + Platform: platform, + Groups: groupsByPlatform[platform], + SupportedModels: toUserSupportedModels(ch.SupportedModels, platformSet), + }) + } + return sections +} + +// filterUserVisibleGroups 仅保留用户可访问的分组。 +func filterUserVisibleGroups( + groups []service.AvailableGroupRef, + allowed map[int64]struct{}, +) []userAvailableGroup { + visible 
:= make([]userAvailableGroup, 0, len(groups)) + for _, g := range groups { + if _, ok := allowed[g.ID]; !ok { + continue + } + visible = append(visible, userAvailableGroup{ + ID: g.ID, + Name: g.Name, + Platform: g.Platform, + SubscriptionType: g.SubscriptionType, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + }) + } + return visible +} + +// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。 +// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。 +// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。 +func toUserSupportedModels( + src []service.SupportedModel, + allowedPlatforms map[string]struct{}, +) []userSupportedModel { + out := make([]userSupportedModel, 0, len(src)) + for i := range src { + m := src[i] + if allowedPlatforms != nil { + if _, ok := allowedPlatforms[m.Platform]; !ok { + continue + } + } + out = append(out, userSupportedModel{ + Name: m.Name, + Platform: m.Platform, + Pricing: toUserPricing(m.Pricing), + }) + } + return out +} + +// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。 +func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing { + if p == nil { + return nil + } + intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals)) + for _, iv := range p.Intervals { + intervals = append(intervals, userPricingIntervalDTO{ + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + }) + } + billingMode := string(p.BillingMode) + if billingMode == "" { + billingMode = string(service.BillingModeToken) + } + return &userSupportedModelPricing{ + BillingMode: billingMode, + InputPrice: p.InputPrice, + OutputPrice: p.OutputPrice, + CacheWritePrice: p.CacheWritePrice, + CacheReadPrice: p.CacheReadPrice, + ImageOutputPrice: p.ImageOutputPrice, + PerRequestPrice: p.PerRequestPrice, + Intervals: intervals, + } +} 
diff --git a/backend/internal/handler/available_channel_handler_test.go b/backend/internal/handler/available_channel_handler_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0a7ce6c466a5cd0f0853ba6806b5c43d6d37ddaf --- /dev/null +++ b/backend/internal/handler/available_channel_handler_test.go @@ -0,0 +1,157 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestUserAvailableChannel_Unauthenticated401(t *testing.T) { + // 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。 + gin.SetMode(gin.TestMode) + h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们 + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil) + + h.List(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) { + // 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。 + groups := []service.AvailableGroupRef{ + {ID: 1, Name: "g1", Platform: "anthropic"}, + {ID: 2, Name: "g2", Platform: "anthropic"}, + {ID: 3, Name: "g3", Platform: "openai"}, + } + allowed := map[int64]struct{}{1: {}, 3: {}} + + visible := filterUserVisibleGroups(groups, allowed) + require.Len(t, visible, 2) + ids := []int64{visible[0].ID, visible[1].ID} + require.ElementsMatch(t, []int64{1, 3}, ids) +} + +func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) { + // 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。 + src := []service.SupportedModel{ + {Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil}, + {Name: "gpt-4o", Platform: "openai", Pricing: nil}, + } + allowed := map[string]struct{}{"anthropic": {}} + out := toUserSupportedModels(src, allowed) + require.Len(t, out, 1) + require.Equal(t, 
"claude-sonnet-4-6", out[0].Name) +} + +func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) { + // 显式传 nil allowedPlatforms 表示不做过滤。 + src := []service.SupportedModel{ + {Name: "a", Platform: "anthropic"}, + {Name: "b", Platform: "openai"}, + } + require.Len(t, toUserSupportedModels(src, nil), 2) +} + +func TestUserAvailableChannel_FieldWhitelist(t *testing.T) { + // 通过序列化 userAvailableChannel 结构体验证响应形状: + // 只有 name / description / platforms;不含管理端字段。 + row := userAvailableChannel{ + Name: "ch", + Description: "d", + Platforms: []userChannelPlatformSection{ + { + Platform: "anthropic", + Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}}, + SupportedModels: []userSupportedModel{}, + }, + }, + } + raw, err := json.Marshal(row) + require.NoError(t, err) + var decoded map[string]any + require.NoError(t, json.Unmarshal(raw, &decoded)) + + for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} { + _, exists := decoded[key] + require.Falsef(t, exists, "user DTO must not expose %q", key) + } + for _, key := range []string{"name", "description", "platforms"} { + _, exists := decoded[key] + require.Truef(t, exists, "user DTO must expose %q", key) + } + + // 验证 section 的字段(platform / groups / supported_models)。 + rawSection, err := json.Marshal(row.Platforms[0]) + require.NoError(t, err) + var sectionDecoded map[string]any + require.NoError(t, json.Unmarshal(rawSection, §ionDecoded)) + for _, key := range []string{"platform", "groups", "supported_models"} { + _, exists := sectionDecoded[key] + require.Truef(t, exists, "platform section must expose %q", key) + } + + // Group DTO 暴露区分专属/公开、订阅类型、默认倍率所需的字段, + // 前端据此渲染 GroupBadge 并与 API 密钥页保持一致的视觉。 + rawGroup, err := json.Marshal(row.Platforms[0].Groups[0]) + require.NoError(t, err) + var groupDecoded map[string]any + require.NoError(t, json.Unmarshal(rawGroup, &groupDecoded)) + for _, key := range []string{"id", "name", "platform", "subscription_type", 
"rate_multiplier", "is_exclusive"} { + _, exists := groupDecoded[key] + require.Truef(t, exists, "group DTO must expose %q", key) + } + + // pricing interval 白名单:不应暴露 id / sort_order。 + pricing := toUserPricing(&service.ChannelModelPricing{ + BillingMode: service.BillingModeToken, + Intervals: []service.PricingInterval{ + {ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3}, + }, + }) + require.NotNil(t, pricing) + require.Len(t, pricing.Intervals, 1) + rawIv, err := json.Marshal(pricing.Intervals[0]) + require.NoError(t, err) + var ivDecoded map[string]any + require.NoError(t, json.Unmarshal(rawIv, &ivDecoded)) + for _, key := range []string{"id", "pricing_id", "sort_order"} { + _, exists := ivDecoded[key] + require.Falsef(t, exists, "user pricing interval must not expose %q", key) + } +} + +func TestBuildPlatformSections_GroupsByPlatform(t *testing.T) { + // 一个渠道横跨 anthropic / openai / 空平台:应该生成 2 个 section, + // 按 platform 字母序排序,各自 groups 和 supported_models 只含同平台条目。 + ch := service.AvailableChannel{ + Name: "ch", + SupportedModels: []service.SupportedModel{ + {Name: "claude-sonnet-4-6", Platform: "anthropic"}, + {Name: "gpt-4o", Platform: "openai"}, + }, + } + visible := []userAvailableGroup{ + {ID: 1, Name: "g-openai", Platform: "openai"}, + {ID: 2, Name: "g-ant", Platform: "anthropic"}, + {ID: 3, Name: "g-empty", Platform: ""}, + } + sections := buildPlatformSections(ch, visible) + require.Len(t, sections, 2) + require.Equal(t, "anthropic", sections[0].Platform) + require.Equal(t, "openai", sections[1].Platform) + require.Len(t, sections[0].Groups, 1) + require.Equal(t, int64(2), sections[0].Groups[0].ID) + require.Len(t, sections[0].SupportedModels, 1) + require.Equal(t, "claude-sonnet-4-6", sections[0].SupportedModels[0].Name) +} diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go new file mode 100644 index 
0000000000000000000000000000000000000000..cc36b3346118bc6a35b4ae726b1038264bb53c31 --- /dev/null +++ b/backend/internal/handler/channel_monitor_user_handler.go @@ -0,0 +1,176 @@ +package handler + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler/admin" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ChannelMonitorUserHandler 渠道监控用户只读 handler。 +type ChannelMonitorUserHandler struct { + monitorService *service.ChannelMonitorService + settingService *service.SettingService +} + +// NewChannelMonitorUserHandler 创建 handler。 +// settingService 用于每次请求前读取功能开关;关闭时 List/GetStatus 直接返回空/404。 +func NewChannelMonitorUserHandler( + monitorService *service.ChannelMonitorService, + settingService *service.SettingService, +) *ChannelMonitorUserHandler { + return &ChannelMonitorUserHandler{ + monitorService: monitorService, + settingService: settingService, + } +} + +// featureEnabled 返回当前渠道监控功能是否开启。 +// settingService 为 nil(测试场景)视为启用。 +func (h *ChannelMonitorUserHandler) featureEnabled(c *gin.Context) bool { + if h.settingService == nil { + return true + } + return h.settingService.GetChannelMonitorRuntime(c.Request.Context()).Enabled +} + +// --- Response --- + +type channelMonitorUserListItem struct { + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + GroupName string `json:"group_name"` + PrimaryModel string `json:"primary_model"` + PrimaryStatus string `json:"primary_status"` + PrimaryLatencyMs *int `json:"primary_latency_ms"` + PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"` + Availability7d float64 `json:"availability_7d"` + ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"` + Timeline []channelMonitorUserTimelinePoint `json:"timeline"` +} + +// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。 +// 仅用于用户视图 list 响应,admin 视图不使用。 +type 
channelMonitorUserTimelinePoint struct { + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + CheckedAt string `json:"checked_at"` +} + +type channelMonitorUserDetailResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + GroupName string `json:"group_name"` + Models []channelMonitorUserModelStat `json:"models"` +} + +type channelMonitorUserModelStat struct { + Model string `json:"model"` + LatestStatus string `json:"latest_status"` + LatestLatencyMs *int `json:"latest_latency_ms"` + Availability7d float64 `json:"availability_7d"` + Availability15d float64 `json:"availability_15d"` + Availability30d float64 `json:"availability_30d"` + AvgLatency7dMs *int `json:"avg_latency_7d_ms"` +} + +func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListItem { + extras := make([]dto.ChannelMonitorExtraModelStatus, 0, len(v.ExtraModels)) + for _, e := range v.ExtraModels { + extras = append(extras, dto.ChannelMonitorExtraModelStatus{ + Model: e.Model, + Status: e.Status, + LatencyMs: e.LatencyMs, + }) + } + timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline)) + for _, p := range v.Timeline { + timeline = append(timeline, channelMonitorUserTimelinePoint{ + Status: p.Status, + LatencyMs: p.LatencyMs, + PingLatencyMs: p.PingLatencyMs, + CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339), + }) + } + return channelMonitorUserListItem{ + ID: v.ID, + Name: v.Name, + Provider: v.Provider, + GroupName: v.GroupName, + PrimaryModel: v.PrimaryModel, + PrimaryStatus: v.PrimaryStatus, + PrimaryLatencyMs: v.PrimaryLatencyMs, + PrimaryPingLatencyMs: v.PrimaryPingLatencyMs, + Availability7d: v.Availability7d, + ExtraModels: extras, + Timeline: timeline, + } +} + +func userMonitorDetailToResponse(d *service.UserMonitorDetail) *channelMonitorUserDetailResponse { + models := make([]channelMonitorUserModelStat, 0, len(d.Models)) + for _, m := 
range d.Models { + models = append(models, channelMonitorUserModelStat{ + Model: m.Model, + LatestStatus: m.LatestStatus, + LatestLatencyMs: m.LatestLatencyMs, + Availability7d: m.Availability7d, + Availability15d: m.Availability15d, + Availability30d: m.Availability30d, + AvgLatency7dMs: m.AvgLatency7dMs, + }) + } + return &channelMonitorUserDetailResponse{ + ID: d.ID, + Name: d.Name, + Provider: d.Provider, + GroupName: d.GroupName, + Models: models, + } +} + +// --- Handlers --- + +// List GET /api/v1/channel-monitors +func (h *ChannelMonitorUserHandler) List(c *gin.Context) { + if !h.featureEnabled(c) { + response.Success(c, gin.H{"items": []channelMonitorUserListItem{}}) + return + } + views, err := h.monitorService.ListUserView(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + items := make([]channelMonitorUserListItem, 0, len(views)) + for _, v := range views { + items = append(items, userMonitorViewToItem(v)) + } + response.Success(c, gin.H{"items": items}) +} + +// GetStatus GET /api/v1/channel-monitors/:id/status +func (h *ChannelMonitorUserHandler) GetStatus(c *gin.Context) { + if !h.featureEnabled(c) { + response.ErrorFrom(c, service.ErrChannelMonitorNotFound) + return + } + // 复用 admin.ParseChannelMonitorID 保持错误码与日志一致。 + id, ok := admin.ParseChannelMonitorID(c) + if !ok { + return + } + detail, err := h.monitorService.GetUserDetail(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, userMonitorDetailToResponse(detail)) +} diff --git a/backend/internal/handler/dto/channel_monitor.go b/backend/internal/handler/dto/channel_monitor.go new file mode 100644 index 0000000000000000000000000000000000000000..3c0c5e119a37fcc2ace66553e9322f313e724c8e --- /dev/null +++ b/backend/internal/handler/dto/channel_monitor.go @@ -0,0 +1,10 @@ +package dto + +// ChannelMonitorExtraModelStatus 渠道监控附加模型最近一次状态。 +// 同时被 admin handler(List 响应)与 user handler(List 响应)复用, +// 
字段必须保持一致以保证前端拿到统一结构。 +type ChannelMonitorExtraModelStatus struct { + Model string `json:"model"` + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d2ccb8d62f827e9b1a7d92a5d90679aedc31df4f..f7503c2ea15b56a0dc66e75b4ece245a91e16e6d 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -21,6 +21,7 @@ func UserFromServiceShallow(u *service.User) *User { Concurrency: u.Concurrency, Status: u.Status, AllowedGroups: u.AllowedGroups, + LastActiveAt: u.LastActiveAt, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, BalanceNotifyEnabled: u.BalanceNotifyEnabled, @@ -28,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User { BalanceNotifyThreshold: u.BalanceNotifyThreshold, BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails), TotalRecharged: u.TotalRecharged, + RPMLimit: u.RPMLimit, } } @@ -66,6 +68,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return &AdminUser{ User: *base, Notes: u.Notes, + LastUsedAt: u.LastUsedAt, GroupRates: u.GroupRates, } } @@ -182,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group { AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, RequirePrivacySet: g.RequirePrivacySet, + RPMLimit: g.RPMLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/handler/dto/public_settings_injection_schema_test.go b/backend/internal/handler/dto/public_settings_injection_schema_test.go new file mode 100644 index 0000000000000000000000000000000000000000..428fed3d8ae543f7d74aedc0c968b9703624fce3 --- /dev/null +++ b/backend/internal/handler/dto/public_settings_injection_schema_test.go @@ -0,0 +1,70 @@ +package dto + +import ( + "reflect" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TestPublicSettingsInjectionPayload_SchemaDoesNotDrift 
guarantees the SSR +// injection struct exposes every JSON field consumed by the frontend. +// +// Why this test exists: before we extracted a named PublicSettingsInjectionPayload +// type, the inline struct was manually kept in sync with dto.PublicSettings and +// drifted — ChannelMonitorEnabled / AvailableChannelsEnabled were missing, which +// made the frontend read `undefined` on refresh and hide the "可用渠道" menu +// until the async /api/v1/settings/public round-trip finished. +// +// This test compares the two JSON-tag sets and fails if injection is missing +// any field that dto.PublicSettings exposes. Adding a new feature flag with +// only a DTO entry will fail this test until the injection struct is updated. +// +// Intentional exclusions (fields present on dto.PublicSettings that SSR does +// not need to inject) are listed in `dtoOnlyFields` below with a reason. +func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) { + injection := jsonTags(reflect.TypeOf(service.PublicSettingsInjectionPayload{})) + dtoKeys := jsonTags(reflect.TypeOf(PublicSettings{})) + + // Fields that legitimately live only on the DTO. Keep tiny; document each. + dtoOnlyFields := map[string]string{ + // sora_client_enabled is an upstream-only field the fork does not surface. + "sora_client_enabled": "upstream-only field, not used on this fork", + // force_email_on_third_party_signup lives on the DTO but is not injected via SSR. 
+ "force_email_on_third_party_signup": "auth-source default, not a feature flag", + } + + var missing []string + for key := range dtoKeys { + if _, ok := injection[key]; ok { + continue + } + if _, allowed := dtoOnlyFields[key]; allowed { + continue + } + missing = append(missing, key) + } + if len(missing) > 0 { + t.Fatalf("service.PublicSettingsInjectionPayload is missing JSON fields present on dto.PublicSettings: %s\n"+ + "add the field to PublicSettingsInjectionPayload (and GetPublicSettingsForInjection), or "+ + "document the exclusion in dtoOnlyFields with a reason.", strings.Join(missing, ", ")) + } +} + +func jsonTags(t reflect.Type) map[string]struct{} { + out := make(map[string]struct{}) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + tag := f.Tag.Get("json") + if tag == "" || tag == "-" { + continue + } + name := strings.SplitN(tag, ",", 2)[0] + if name == "" { + continue + } + out[name] = struct{}{} + } + return out +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 3659e79be3ed9a0aeab4dc2838ec25957d5b026f..2affbc4611f7aa7c19dca8404d2c26ccfe0c3411 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,6 +51,23 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` + WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` + WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"` + WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` + WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"` + 
WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` + WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"` + WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` + WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` + WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` OIDCConnectClientID string `json:"oidc_connect_client_id"` @@ -91,6 +108,7 @@ type SystemSettings struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration @@ -127,6 +145,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"` + // Payment configuration PaymentEnabled bool `json:"payment_enabled"` PaymentMinAmount float64 `json:"payment_min_amount"` @@ -157,6 +184,13 @@ type SystemSettings struct { BalanceLowNotifyRechargeURL 
string `json:"balance_low_notify_recharge_url"` AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` + + // Channel Monitor feature switch + ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` + + // Available Channels feature switch (user-facing aggregate view) + AvailableChannelsEnabled bool `json:"available_channels_enabled"` } type DefaultSubscriptionSetting struct { @@ -167,6 +201,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` @@ -189,6 +224,10 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` SoraClientEnabled bool `json:"sora_client_enabled"` @@ -199,6 +238,11 @@ type PublicSettings struct { AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + + ChannelMonitorEnabled bool 
`json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` + + AvailableChannelsEnabled bool `json:"available_channels_enabled"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8c1e166f3d69dba16ebd90e2cb7571f679593fd8..5cc2f8e4dd49a325d874c51f1363acb6991cacdd 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -7,16 +7,17 @@ import ( ) type User struct { - ID int64 `json:"id"` - Email string `json:"email"` - Username string `json:"username"` - Role string `json:"role"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - Status string `json:"status"` - AllowedGroups []int64 `json:"allowed_groups"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + Email string `json:"email"` + Username string `json:"username"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + Status string `json:"status"` + AllowedGroups []int64 `json:"allowed_groups"` + LastActiveAt *time.Time `json:"last_active_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` // 余额不足通知 BalanceNotifyEnabled bool `json:"balance_notify_enabled"` @@ -25,6 +26,9 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` TotalRecharged float64 `json:"total_recharged"` + // RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。 + RPMLimit int `json:"rpm_limit"` + APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` } @@ -34,7 +38,8 @@ type User struct { type AdminUser struct { User - Notes string `json:"notes"` + Notes string `json:"notes"` + LastUsedAt *time.Time `json:"last_used_at"` // GroupRates 用户专属分组倍率配置 // 
map[groupID]rateMultiplier GroupRates map[int64]float64 `json:"group_rates,omitempty"` @@ -106,6 +111,9 @@ type Group struct { RequireOAuthOnly bool `json:"require_oauth_only"` RequirePrivacySet bool `json:"require_privacy_set"` + // RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。 + RPMLimit int `json:"rpm_limit"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a17f0ce486549d0243f4f30e7f0cbd5eea0dfbb4 --- /dev/null +++ b/backend/internal/handler/dto/user_mapper_activity_test.go @@ -0,0 +1,33 @@ +package dto + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) { + t.Parallel() + + lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC) + lastActiveAt := lastLoginAt.Add(15 * time.Minute) + lastUsedAt := lastLoginAt.Add(45 * time.Minute) + + out := UserFromServiceAdmin(&service.User{ + ID: 42, + Email: "admin@example.com", + Username: "admin", + Role: service.RoleAdmin, + Status: service.StatusActive, + LastActiveAt: &lastActiveAt, + LastUsedAt: &lastUsedAt, + }) + + require.NotNil(t, out) + require.NotNil(t, out.LastActiveAt) + require.NotNil(t, out.LastUsedAt) + require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second) + require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second) +} diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index a897bc40541f28d7555e6a06b9f93be30c262c99..db29618aef2bf1110667510329213320c7bf0ac1 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -15,10 +15,12 @@ import ( // ────────────────────────────────────────────────────────── const ( - 
EndpointMessages = "/v1/messages" - EndpointChatCompletions = "/v1/chat/completions" - EndpointResponses = "/v1/responses" - EndpointGeminiModels = "/v1beta/models" + EndpointMessages = "/v1/messages" + EndpointChatCompletions = "/v1/chat/completions" + EndpointResponses = "/v1/responses" + EndpointImagesGenerations = "/v1/images/generations" + EndpointImagesEdits = "/v1/images/edits" + EndpointGeminiModels = "/v1beta/models" ) // gin.Context keys used by the middleware and helpers below. @@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string { return EndpointChatCompletions case strings.Contains(path, EndpointMessages): return EndpointMessages + case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"): + return EndpointImagesGenerations + case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"): + return EndpointImagesEdits case strings.Contains(path, EndpointResponses): return EndpointResponses case strings.Contains(path, EndpointGeminiModels): @@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { switch platform { case service.PlatformOpenAI: + if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { + return inbound + } // OpenAI forwards everything to the Responses API. // Preserve subresource suffix (e.g. /v1/responses/compact). 
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" { diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index 1519bc9e62196f6dab2e47d3e23dc46dd8603d60..369c5fa79d43224c9c83851e466917f085581a8d 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) { {"/v1/messages", EndpointMessages}, {"/v1/chat/completions", EndpointChatCompletions}, {"/v1/responses", EndpointResponses}, + {"/v1/images/generations", EndpointImagesGenerations}, + {"/v1/images/edits", EndpointImagesEdits}, {"/v1beta/models", EndpointGeminiModels}, // Prefixed paths (antigravity, openai). {"/antigravity/v1/messages", EndpointMessages}, {"/openai/v1/responses", EndpointResponses}, {"/openai/v1/responses/compact", EndpointResponses}, + {"/openai/v1/images/generations", EndpointImagesGenerations}, + {"/openai/v1/images/edits", EndpointImagesEdits}, {"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels}, // Gin route patterns with wildcards. @@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { {"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"}, {"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses}, {"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses}, + {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations}, + {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits}, // Antigravity — uses inbound to pick Claude vs Gemini upstream. 
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages}, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index f5eff8c9a517f3d9a4b33e9bacf799225591053e..ef53255973f61bf76e57a68f5d98d7895172c35c 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -301,6 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { + reqLog.Warn("gateway.select_account_no_available", + zap.String("model", reqModel), + zap.Int64p("group_id", apiKey.GroupID), + zap.String("platform", platform), + zap.Error(err), + ) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } @@ -344,6 +353,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + reqLog.Warn("gateway.select_account_no_slot_no_wait_plan", + zap.Int64("account_id", account.ID), + zap.String("model", reqModel), + zap.String("platform", platform), + ) 
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } @@ -525,6 +539,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { + reqLog.Warn("gateway.select_account_no_available", + zap.String("model", reqModel), + zap.Int64p("group_id", currentAPIKey.GroupID), + zap.String("platform", platform), + zap.Bool("fallback_used", fallbackUsed), + zap.Error(err), + ) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } @@ -568,6 +589,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + reqLog.Warn("gateway.select_account_no_slot_no_wait_plan", + zap.Int64("account_id", account.ID), + zap.String("model", reqModel), + zap.String("platform", platform), + ) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } @@ -735,7 +761,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -1441,7 +1470,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 if err := 
h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.errorResponse(c, status, code, message) return } @@ -1684,25 +1716,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter c.JSON(http.StatusOK, response) } -func billingErrorDetails(err error) (status int, code, message string) { +func billingErrorDetails(err error) (status int, code, message string, retryAfter int) { if errors.Is(err, service.ErrBillingServiceUnavailable) { msg := pkgerrors.Message(err) if msg == "" { msg = "Billing service temporarily unavailable. Please retry later." } - return http.StatusServiceUnavailable, "billing_service_error", msg + return http.StatusServiceUnavailable, "billing_service_error", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 + } + // 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。 + // 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。 + if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) { + msg := pkgerrors.Message(err) + retrySeconds := 60 - int(time.Now().Unix()%60) + return http.StatusTooManyRequests, 
"rate_limit_exceeded", msg, retrySeconds } msg := pkgerrors.Message(err) if msg == "" { @@ -1712,7 +1751,7 @@ func billingErrorDetails(err error) (status int, code, message string) { ).Warn("gateway.billing_error_missing_message") msg = "Billing error" } - return http.StatusForbidden, "billing_error", msg + return http.StatusForbidden, "billing_error", msg, 0 } func (h *GatewayHandler) metadataBridgeEnabled() bool { diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e8a888029039618b485c2b26950c1ef7d43b247e --- /dev/null +++ b/backend/internal/handler/gateway_handler_billing_error_test.go @@ -0,0 +1,54 @@ +package handler + +import ( + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) { + // 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。 + for _, err := range []error{ + service.ErrAPIKeyRateLimit5hExceeded, + 
service.ErrAPIKeyRateLimit1dExceeded, + service.ErrAPIKeyRateLimit7dExceeded, + } { + status, code, _, _ := billingErrorDetails(err) + require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err) + require.Equal(t, "rate_limit_exceeded", code) + } +} + +func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) { + status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable) + require.Equal(t, http.StatusServiceUnavailable, status) + require.Equal(t, "billing_service_error", code) + require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After") +} + +func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) { + status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance) + require.Equal(t, http.StatusForbidden, status) + require.Equal(t, "billing_error", code) + require.NotEmpty(t, msg) +} diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index be267332d88a17d2c6650e6660795aaadcde8000..4290e54bfbbfb4899a5ae4dba51fdb1ea21bd6a4 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 2. 
Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.chatCompletionsErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e908eb9e0fa821a2f66571a62baa88b6834b1dad..683cf2b7719594a06d46eea1d8f8362958cca92b 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 2. 
Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.responsesErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 1fdc46ba28ac1602f8711e99bd6509e1b14ea0fe..71030140e6787e91be0325cbcd7f194a65ef2c65 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d200c17ce51f96c5a4ab1e37969c517e5c7c3095..2a34e3f079d000b9d13a2c279dbc138288b5da3b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -9,6 +9,7 @@ import ( "errors" "net/http" "regexp" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := 
h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) - status, _, message := billingErrorDetails(err) + status, _, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } googleError(c, status, message) return } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 906a74f148fef7a37080500a84a6db2a541984d7..aee9d927dd9757da8e008c566a222f8caf2c3b20 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -6,50 +6,54 @@ import ( // AdminHandlers contains all admin-related HTTP handlers type AdminHandlers struct { - Dashboard *admin.DashboardHandler - User *admin.UserHandler - Group *admin.GroupHandler - Account *admin.AccountHandler - Announcement *admin.AnnouncementHandler - DataManagement *admin.DataManagementHandler - Backup *admin.BackupHandler - OAuth *admin.OAuthHandler - OpenAIOAuth *admin.OpenAIOAuthHandler - GeminiOAuth *admin.GeminiOAuthHandler - AntigravityOAuth *admin.AntigravityOAuthHandler - Proxy *admin.ProxyHandler - Redeem *admin.RedeemHandler - Promo *admin.PromoHandler - Setting *admin.SettingHandler - Ops *admin.OpsHandler - System *admin.SystemHandler - Subscription *admin.SubscriptionHandler - Usage *admin.UsageHandler - UserAttribute *admin.UserAttributeHandler - ErrorPassthrough *admin.ErrorPassthroughHandler - TLSFingerprintProfile *admin.TLSFingerprintProfileHandler - APIKey *admin.AdminAPIKeyHandler - ScheduledTest *admin.ScheduledTestHandler - Channel *admin.ChannelHandler - Payment *admin.PaymentHandler + Dashboard *admin.DashboardHandler + User *admin.UserHandler + Group *admin.GroupHandler + Account *admin.AccountHandler + Announcement *admin.AnnouncementHandler + DataManagement *admin.DataManagementHandler + Backup *admin.BackupHandler + OAuth 
*admin.OAuthHandler + OpenAIOAuth *admin.OpenAIOAuthHandler + GeminiOAuth *admin.GeminiOAuthHandler + AntigravityOAuth *admin.AntigravityOAuthHandler + Proxy *admin.ProxyHandler + Redeem *admin.RedeemHandler + Promo *admin.PromoHandler + Setting *admin.SettingHandler + Ops *admin.OpsHandler + System *admin.SystemHandler + Subscription *admin.SubscriptionHandler + Usage *admin.UsageHandler + UserAttribute *admin.UserAttributeHandler + ErrorPassthrough *admin.ErrorPassthroughHandler + TLSFingerprintProfile *admin.TLSFingerprintProfileHandler + APIKey *admin.AdminAPIKeyHandler + ScheduledTest *admin.ScheduledTestHandler + Channel *admin.ChannelHandler + ChannelMonitor *admin.ChannelMonitorHandler + ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler + Payment *admin.PaymentHandler } // Handlers contains all HTTP handlers type Handlers struct { - Auth *AuthHandler - User *UserHandler - APIKey *APIKeyHandler - Usage *UsageHandler - Redeem *RedeemHandler - Subscription *SubscriptionHandler - Announcement *AnnouncementHandler - Admin *AdminHandlers - Gateway *GatewayHandler - OpenAIGateway *OpenAIGatewayHandler - Setting *SettingHandler - Totp *TotpHandler - Payment *PaymentHandler - PaymentWebhook *PaymentWebhookHandler + Auth *AuthHandler + User *UserHandler + APIKey *APIKeyHandler + Usage *UsageHandler + Redeem *RedeemHandler + Subscription *SubscriptionHandler + Announcement *AnnouncementHandler + ChannelMonitor *ChannelMonitorUserHandler + Admin *AdminHandlers + Gateway *GatewayHandler + OpenAIGateway *OpenAIGatewayHandler + Setting *SettingHandler + Totp *TotpHandler + Payment *PaymentHandler + PaymentWebhook *PaymentWebhookHandler + AvailableChannel *AvailableChannelHandler } // BuildInfo contains build-time information diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 991cbb91f704aa5d35e05177f272bf6553454cca..3c4e62515b5d64edb12fe04b4cc905df331e7d78 100644 --- 
a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 5319b55d9dddf85a13200e749ee00a15128f1de1..1c97557382a2a69b20d6730d1d04cfc65a30b8f7 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") return } + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_requires_wsv2"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2") + return } setOpsRequestContext(c, reqModel, reqStream, body) @@ -223,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 2. 
Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -589,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.anthropicStreamingAwareError(c, status, code, message, streamStarted) return } @@ -856,7 +867,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, reqLog.Warn("openai.request_validation_failed", zap.String("reason", "function_call_output_missing_call_id"), ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2") return false } if validation.HasItemReferenceForAllCallIDs { @@ -866,7 +877,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, reqLog.Warn("openai.request_validation_failed", zap.String("reason", 
"function_call_output_missing_item_reference"), ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2") return false } diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index d299fb81e338120a49052581b8dd00ac813c6ab1..8ecee59ae79d47e980d24b83824369c530b98d76 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") } +func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Responses 
WebSocket v2") + require.Contains(t, w.Body.String(), "previous_response_id") +} + +func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Responses WebSocket v2") + require.NotContains(t, w.Body.String(), "reuse previous_response_id") +} + func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go new file mode 100644 index 0000000000000000000000000000000000000000..403b41ef6eae4274bc45ae5127a562c79a93065b --- /dev/null +++ b/backend/internal/handler/openai_images.go @@ -0,0 +1,304 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "strconv" + "strings" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Images handles OpenAI Images API requests. 
+// POST /v1/images/generations +// POST /v1/images/edits +func (h *OpenAIGatewayHandler) Images(c *gin.Context) { + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.images", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if isMultipartImagesContentType(c.GetHeader("Content-Type")) { + setOpsRequestContext(c, "", false, nil) + } else { + setOpsRequestContext(c, "", false, body) + } + + parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return + } + + reqLog = reqLog.With( + zap.String("model", parsed.Model), + zap.Bool("stream", parsed.Stream), + zap.Bool("multipart", parsed.Multipart), + zap.String("capability", string(parsed.RequiredCapability)), + ) + + if parsed.Multipart { + setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) + } else { + 
setOpsRequestContext(c, parsed.Model, parsed.Stream, body) + } + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false))) + + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model) + + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := "" + if parsed.Multipart { + sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed()) + } else { + sessionHash = h.gatewayService.GenerateSessionHash(c, body) + } + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages( + c.Request.Context(), + apiKey.GroupID, + sessionHash, + parsed.Model, + 
failedAccountIDs, + parsed.RequiredCapability, + ) + if err != nil { + reqLog.Warn("openai.images.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) + return + } + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } + return + } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) + return + } + + reqLog.Debug("openai.images.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) + + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() + } + 
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.images.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.images.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", 
wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.images.forward_failed", fields...) + return + } + reqLog.Error("openai.images.forward_failed", fields...) + return + } + + if result != nil { + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + if parsed.Multipart { + requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) + } + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.images"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", parsed.Model), + zap.Int64("account_id", account.ID), + ).Error("openai.images.record_usage_failed", zap.Error(err)) + } + }) + + reqLog.Debug("openai.images.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} + +func isMultipartImagesContentType(contentType string) bool 
{ + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data") +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 90e90dd074e68793ab6e4ca4c64a23ffc335201f..935549121bbb0804986b0b6e22bac0954cb63609 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string { return service.PlatformAntigravity case strings.HasPrefix(p, "/v1beta/"): return service.PlatformGemini - case strings.Contains(p, "/responses"): + case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"): return service.PlatformOpenAI default: return "" diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1ddb8ae2dcfdeb2418d8822d1beace3cac85f80d..09580442d74ca8e123a040eebbb4f466c752f4d4 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -1,9 +1,14 @@ package handler import ( + "fmt" "strconv" "strings" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -202,10 +207,18 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) { // CreateOrderRequest is the request body for creating a payment order. 
type CreateOrderRequest struct { - Amount float64 `json:"amount"` - PaymentType string `json:"payment_type" binding:"required"` - OrderType string `json:"order_type"` - PlanID int64 `json:"plan_id"` + Amount float64 `json:"amount"` + PaymentType string `json:"payment_type" binding:"required"` + OpenID string `json:"openid"` + WechatResumeToken string `json:"wechat_resume_token"` + ReturnURL string `json:"return_url"` + PaymentSource string `json:"payment_source"` + OrderType string `json:"order_type"` + PlanID int64 `json:"plan_id"` + // IsMobile lets the frontend declare its mobile status directly. When + // nil we fall back to User-Agent heuristics (which miss iPadOS / some + // embedded browsers that strip the "Mobile" keyword). + IsMobile *bool `json:"is_mobile,omitempty"` } // CreateOrder creates a new payment order. @@ -221,17 +234,36 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { response.BadRequest(c, "Invalid request: "+err.Error()) return } + if strings.TrimSpace(req.WechatResumeToken) != "" { + claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil { + response.ErrorFrom(c, err) + return + } + } + mobile := isMobile(c) + if req.IsMobile != nil { + mobile = *req.IsMobile + } result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ - UserID: subject.UserID, - Amount: req.Amount, - PaymentType: req.PaymentType, - ClientIP: c.ClientIP(), - IsMobile: isMobile(c), - SrcHost: c.Request.Host, - SrcURL: c.Request.Referer(), - OrderType: req.OrderType, - PlanID: req.PlanID, + UserID: subject.UserID, + Amount: req.Amount, + PaymentType: req.PaymentType, + OpenID: req.OpenID, + ClientIP: c.ClientIP(), + IsMobile: mobile, + IsWeChatBrowser: isWeChatBrowser(c), + SrcHost: c.Request.Host, + SrcURL: c.Request.Referer(), + ReturnURL: req.ReturnURL, + PaymentSource: 
req.PaymentSource, + OrderType: req.OrderType, + PlanID: req.PlanID, }) if err != nil { response.ErrorFrom(c, err) @@ -240,6 +272,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { response.Success(c, result) } +func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error { + if req == nil || claims == nil { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing") + } + openid := strings.TrimSpace(claims.OpenID) + if openid == "" { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid") + } + + paymentType := service.NormalizeVisibleMethod(claims.PaymentType) + if paymentType == "" { + paymentType = payment.TypeWxpay + } + if req.PaymentType != "" { + requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType) + if requestPaymentType != "" && requestPaymentType != paymentType { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch") + } + } + req.PaymentType = paymentType + req.OpenID = openid + + if strings.TrimSpace(claims.Amount) != "" { + amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64) + if err != nil || amount <= 0 { + return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount)) + } + req.Amount = amount + } + if claims.OrderType != "" { + req.OrderType = claims.OrderType + } + if claims.PlanID > 0 { + req.PlanID = claims.PlanID + } + return nil +} + // GetMyOrders returns the authenticated user's orders. 
// GET /api/v1/payment/orders/my func (h *PaymentHandler) GetMyOrders(c *gin.Context) { @@ -260,7 +330,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Paginated(c, orders, int64(total), page, pageSize) + response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize) } // GetOrder returns a single order for the authenticated user. @@ -282,7 +352,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, order) + response.Success(c, sanitizePaymentOrderForResponse(order)) } // CancelOrder cancels a pending order for the authenticated user. @@ -354,6 +424,10 @@ type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` } +type ResolveOrderByResumeTokenRequest struct { + ResumeToken string `json:"resume_token" binding:"required"` +} + // VerifyOrder actively queries the upstream payment provider to check // if payment was made, and processes it if so. // POST /api/v1/payment/orders/verify @@ -374,23 +448,57 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, order) + response.Success(c, sanitizePaymentOrderForResponse(order)) } // PublicOrderResult is the limited order info returned by the public verify endpoint. // No user details are exposed — only payment status information. 
type PublicOrderResult struct { - ID int64 `json:"id"` - OutTradeNo string `json:"out_trade_no"` - Amount float64 `json:"amount"` - PayAmount float64 `json:"pay_amount"` - PaymentType string `json:"payment_type"` - OrderType string `json:"order_type"` - Status string `json:"status"` + ID int64 `json:"id"` + OutTradeNo string `json:"out_trade_no"` + Amount float64 `json:"amount"` + PayAmount float64 `json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + PaymentType string `json:"payment_type"` + OrderType string `json:"order_type"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + PaidAt *time.Time `json:"paid_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + RefundAmount float64 `json:"refund_amount"` + RefundReason *string `json:"refund_reason,omitempty"` + RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"` + RefundRequestedBy *string `json:"refund_requested_by,omitempty"` + RefundRequestReason *string `json:"refund_request_reason,omitempty"` + PlanID *int64 `json:"plan_id,omitempty"` } -// VerifyOrderPublic verifies payment status without requiring authentication. -// Returns limited order info (no user details) to prevent information leakage. 
+func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult { + return PublicOrderResult{ + ID: order.ID, + OutTradeNo: order.OutTradeNo, + Amount: order.Amount, + PayAmount: order.PayAmount, + FeeRate: order.FeeRate, + PaymentType: order.PaymentType, + OrderType: order.OrderType, + Status: order.Status, + CreatedAt: order.CreatedAt, + ExpiresAt: order.ExpiresAt, + PaidAt: order.PaidAt, + CompletedAt: order.CompletedAt, + RefundAmount: order.RefundAmount, + RefundReason: order.RefundReason, + RefundRequestedAt: order.RefundRequestedAt, + RefundRequestedBy: order.RefundRequestedBy, + RefundRequestReason: order.RefundRequestReason, + PlanID: order.PlanID, + } +} + +// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as +// a compatibility path for older result pages and staggered deploys. // POST /api/v1/payment/public/orders/verify func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { var req VerifyOrderRequest @@ -398,20 +506,30 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { response.BadRequest(c, "Invalid request: "+err.Error()) return } + order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, PublicOrderResult{ - ID: order.ID, - OutTradeNo: order.OutTradeNo, - Amount: order.Amount, - PayAmount: order.PayAmount, - PaymentType: order.PaymentType, - OrderType: order.OrderType, - Status: order.Status, - }) + response.Success(c, buildPublicOrderResult(order)) +} + +// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token. 
+// POST /api/v1/payment/public/orders/resolve +func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) { + var req ResolveOrderByResumeTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, buildPublicOrderResult(order)) } // requireAuth extracts the authenticated subject from the context. @@ -435,3 +553,27 @@ func isMobile(c *gin.Context) bool { } return false } + +func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder { + if len(orders) == 0 { + return orders + } + out := make([]*dbent.PaymentOrder, 0, len(orders)) + for _, order := range orders { + out = append(out, sanitizePaymentOrderForResponse(order)) + } + return out +} + +func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder { + if order == nil { + return nil + } + cloned := *order + cloned.ProviderSnapshot = nil + return &cloned +} + +func isWeChatBrowser(c *gin.Context) bool { + return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger") +} diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a7bc4ba3a92698e5c98a27eb7019274611895969 --- /dev/null +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -0,0 +1,368 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + 
"github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestApplyWeChatPaymentResumeClaims(t *testing.T) { + t.Parallel() + + req := CreateOrderRequest{ + Amount: 0, + PaymentType: payment.TypeWxpay, + OrderType: payment.OrderTypeBalance, + } + + err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeSubscription, + PlanID: 7, + }) + if err != nil { + t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err) + } + if req.OpenID != "openid-123" { + t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123") + } + if req.Amount != 12.5 { + t.Fatalf("amount = %v, want 12.5", req.Amount) + } + if req.OrderType != payment.OrderTypeSubscription { + t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription) + } + if req.PlanID != 7 { + t.Fatalf("plan_id = %d, want 7", req.PlanID) + } +} + +func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) { + t.Parallel() + + req := CreateOrderRequest{ + PaymentType: payment.TypeAlipay, + } + + err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeBalance, + }) + if err == nil { + t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types") + } +} + +func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + 
t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-verify@example.com"). + SetPasswordHash("hash"). + SetUsername("public-verify-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(90.64). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-VERIFY"). + SetOutTradeNo("legacy-order-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-verify"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(context.Background()) + require.NoError(t, err) + + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/verify", + bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.VerifyOrderPublic(ctx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ID int64 `json:"id"` + OutTradeNo string `json:"out_trade_no"` + Amount float64 `json:"amount"` + PayAmount float64 `json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + PaymentType string `json:"payment_type"` + OrderType string `json:"order_type"` + Status string `json:"status"` + RefundAmount float64 `json:"refund_amount"` + CreatedAt string `json:"created_at"` + ExpiresAt string `json:"expires_at"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, order.ID, 
resp.Data.ID) + require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo) + require.Equal(t, 90.64, resp.Data.PayAmount) + require.Equal(t, 0.03, resp.Data.FeeRate) + require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType) + require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType) + require.Equal(t, service.OrderStatusPending, resp.Data.Status) + require.Equal(t, 0.0, resp.Data.RefundAmount) + require.NotEmpty(t, resp.Data.CreatedAt) + require.NotEmpty(t, resp.Data.ExpiresAt) +} + +func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + + db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-resolve@example.com"). + SetPasswordHash("hash"). + SetUsername("public-resolve-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(100). + SetPayAmount(103). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-RESOLVE"). + SetOutTradeNo("resolve-order-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-resolve"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(context.Background()) + require.NoError(t, err) + + resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/resolve", + bytes.NewBufferString(`{"resume_token":"`+token+`"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.ResolveOrderPublicByResumeToken(ctx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, float64(order.ID), resp.Data["id"]) + require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"]) + require.Equal(t, 100.0, resp.Data["amount"]) + require.Equal(t, 103.0, resp.Data["pay_amount"]) + require.Equal(t, 0.03, resp.Data["fee_rate"]) + require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"]) + require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"]) + require.Equal(t, service.OrderStatusPaid, resp.Data["status"]) + require.Contains(t, resp.Data, "created_at") + require.Contains(t, resp.Data, "expires_at") + require.Contains(t, resp.Data, "refund_amount") +} + +func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) { + gin.SetMode(gin.TestMode) + 
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + + db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-resolve-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("public-resolve-mismatch-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(100). + SetPayAmount(103). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-RESOLVE-MISMATCH"). + SetOutTradeNo("resolve-order-mismatch-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-resolve-mismatch"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(context.Background()) + require.NoError(t, err) + + resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID + 999, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/resolve", + bytes.NewBufferString(`{"resume_token":"`+token+`"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.ResolveOrderPublicByResumeToken(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + Message string `json:"message"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason) +} + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, 
nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/verify", + bytes.NewBufferString(`{"out_trade_no":" "}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.VerifyOrderPublic(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason) +} diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfebb72830b9957c83f4cf419814ffbfc2a0..9ae799fd3459945bd1bd5c4e9670bea4cd6c6b4b 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -1,6 +1,9 @@ package handler import ( + "context" + "errors" + "fmt" "io" "log/slog" "net/http" @@ -77,9 +80,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts). 
outTradeNo := extractOutTradeNo(rawBody, providerKey) - provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo) + providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo) if err != nil { slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err) + if providerKey == payment.TypeWxpay { + c.String(http.StatusBadRequest, "verify failed") + return + } writeSuccessResponse(c, providerKey) return } @@ -89,7 +96,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) headers[strings.ToLower(k)] = c.GetHeader(k) } - notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers) + resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers) if err != nil { truncatedBody := rawBody if len(truncatedBody) > webhookLogTruncateLen { @@ -103,24 +110,38 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // nil notification means irrelevant event (e.g. Stripe non-payment event); return success. if notification == nil { - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) return } - if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil { - slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err) + if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil { + // Unknown order: ack with 2xx so the provider stops retrying. This + // guards against foreign environments whose webhook endpoints are + // (mis)configured to point at us — without a 2xx, the provider will + // retry for days and spam our error logs. We still emit a WARN so the + // event is discoverable in logs. 
+ if errors.Is(err, service.ErrOrderNotFound) { + slog.Warn("[Payment Webhook] unknown order, acking to stop retries", + "provider", resolvedProviderKey, + "outTradeNo", notification.OrderID, + "tradeNo", notification.TradeNo, + ) + writeSuccessResponse(c, resolvedProviderKey) + return + } + slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err) c.String(http.StatusInternalServerError, "handle failed") return } - writeSuccessResponse(c, providerKey) + writeSuccessResponse(c, resolvedProviderKey) } // extractOutTradeNo parses the webhook body to find the out_trade_no. // This allows looking up the correct provider instance before verification. func extractOutTradeNo(rawBody, providerKey string) string { switch providerKey { - case payment.TypeEasyPay: + case payment.TypeEasyPay, payment.TypeAlipay: values, err := url.ParseQuery(rawBody) if err == nil { return values.Get("out_trade_no") @@ -131,6 +152,25 @@ func extractOutTradeNo(rawBody, providerKey string) string { return "" } +func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) { + var lastErr error + for _, provider := range providers { + if provider == nil { + continue + } + notification, err := provider.VerifyNotification(ctx, rawBody, headers) + if err != nil { + lastErr = err + continue + } + return provider.ProviderKey(), notification, nil + } + if lastErr != nil { + return "", nil, lastErr + } + return "", nil, fmt.Errorf("no webhook provider could verify notification") +} + // wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook. 
type wxpaySuccessResponse struct { Code string `json:"code"` diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index bdef1766d91108a7ab4656bfc29715667a3a40aa..7551fc83ddb429f65ecb031fe0f825833e2bdafc 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -3,11 +3,16 @@ package handler import ( + "context" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -88,6 +93,43 @@ func TestWriteSuccessResponse(t *testing.T) { } } +// TestUnknownOrderWebhookAcksWithSuccess exercises the response contract that +// handleNotify relies on when HandlePaymentNotification returns ErrOrderNotFound: +// we still need to emit the provider-specific 2xx so the provider stops +// retrying. We can't easily drive handleNotify end-to-end without mocking the +// concrete *service.PaymentService, so this test locks down the two ingredients +// the fix depends on: +// 1. errors.Is recognises the sentinel through fmt.Errorf %w wrapping (which +// is how service layer wraps it with the out_trade_no context). +// 2. writeSuccessResponse produces the provider-specific body for Stripe +// (empty 200) — matching what handleNotify calls on the ack path. +// +// If either contract breaks, the Stripe "unknown order → 500 loop" regresses. +func TestUnknownOrderWebhookAcksWithSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 1) Sentinel recognition through wrapping. 
+ wrapped := fmt.Errorf("%w: out_trade_no=sub2_missing_42", service.ErrOrderNotFound) + require.True(t, errors.Is(wrapped, service.ErrOrderNotFound), + "handleNotify uses errors.Is on the wrapped service error; regression here "+ + "would mean unknown-order webhooks go back to returning 500 and looping forever") + + // A distinct error must NOT match — otherwise a DB failure would be silently + // swallowed as an ack. + other := errors.New("lookup order failed: connection refused") + require.False(t, errors.Is(other, service.ErrOrderNotFound)) + + // 2) Provider-specific success body is what handleNotify emits on the + // ack path. Asserted again here because this is the shape Stripe expects + // to consider the webhook acknowledged. + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + writeSuccessResponse(c, payment.TypeStripe) + require.Equal(t, http.StatusOK, w.Code, + "Stripe requires 2xx to stop retrying; anything else restarts the retry loop") + require.Empty(t, w.Body.String(), "Stripe expects an empty body on the ack path") +} + func TestWebhookConstants(t *testing.T) { t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) { assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize)) @@ -97,3 +139,104 @@ func TestWebhookConstants(t *testing.T) { assert.Equal(t, 200, webhookLogTruncateLen) }) } + +func TestExtractOutTradeNo(t *testing.T) { + tests := []struct { + name string + providerKey string + rawBody string + want string + }{ + { + name: "easypay query payload", + providerKey: "easypay", + rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS", + want: "sub2_123", + }, + { + name: "alipay query payload", + providerKey: "alipay", + rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456", + want: "sub2_456", + }, + { + name: "unknown provider", + providerKey: "wxpay", + rawBody: "{}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, 
extractOutTradeNo(tt.rawBody, tt.providerKey)) + }) + } +} + +func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) { + firstErr := errors.New("wrong provider") + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: firstErr, + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + notification: &payment.PaymentNotification{ + OrderID: "sub2_42", + TradeNo: "trade-42", + Status: payment.NotificationStatusSuccess, + }, + }, + } + + providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"}) + require.NoError(t, err) + require.Equal(t, payment.TypeWxpay, providerKey) + require.NotNil(t, notification) + require.Equal(t, "sub2_42", notification.OrderID) +} + +func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) { + providers := []payment.Provider{ + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed a"), + }, + webhookHandlerProviderStub{ + key: payment.TypeWxpay, + verifyErr: errors.New("verify failed b"), + }, + } + + _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil) + require.Error(t, err) +} + +type webhookHandlerProviderStub struct { + key string + notification *payment.PaymentNotification + verifyErr error +} + +func (p webhookHandlerProviderStub) Name() string { return p.key } +func (p webhookHandlerProviderStub) ProviderKey() string { return p.key } +func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.PaymentType(p.key)} +} +func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected 
call") +} +func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + if p.verifyErr != nil { + return nil, p.verifyErr + } + return p.notification, nil +} +func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1717b7a1e9afc666465d0dedcb3d457fe530a1a2..96964de41866ceaad06482318ccb13d06f638e5f 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { response.Success(c, dto.PublicSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, @@ -56,6 +57,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, + WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, + WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled, + WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, BackendModeEnabled: settings.BackendModeEnabled, @@ -65,5 +70,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, BalanceLowNotifyThreshold: 
settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + + ChannelMonitorEnabled: settings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, + + AvailableChannelsEnabled: settings.AvailableChannelsEnabled, }) } diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go new file mode 100644 index 0000000000000000000000000000000000000000..45d66f8e337ed5c4647518976dcbdbaf157a79a1 --- /dev/null +++ b/backend/internal/handler/setting_handler_public_test.go @@ -0,0 +1,122 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerPublicRepoStub struct { + values map[string]string +} + +func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} 
+ +func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.ForceEmailOnThirdPartySignup) +} + +func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { + gin.SetMode(gin.TestMode) + h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-mp-app", + service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectOpenEnabled: "true", + service.SettingKeyWeChatConnectMPEnabled: "true", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = 
httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.WeChatOAuthEnabled) + require.True(t, resp.Data.WeChatOAuthOpenEnabled) + require.True(t, resp.Data.WeChatOAuthMPEnabled) +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 2535ea5e6f60e4da7761a0edcd4915a7cb143e22..f74c2b72265303154e8f1de59d043ae243d9f0c6 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -1,6 +1,9 @@ package handler import ( + "context" + "strings" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -12,14 +15,21 @@ import ( // UserHandler handles user-related requests type UserHandler struct { userService *service.UserService + authService *service.AuthService emailService *service.EmailService emailCache service.EmailCache } // NewUserHandler creates a new UserHandler -func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler { +func NewUserHandler( + userService *service.UserService, + authService *service.AuthService, + emailService *service.EmailService, + emailCache service.EmailCache, +) *UserHandler { return &UserHandler{ userService: userService, + authService: authService, emailService: emailService, emailCache: emailCache, } @@ -34,10 +44,33 @@ type ChangePasswordRequest struct { // UpdateProfileRequest 
represents the update profile request payload type UpdateProfileRequest struct { Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type userProfileResponse struct { + dto.User + AvatarURL string `json:"avatar_url,omitempty"` + AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"` + UsernameSource *userProfileSourceContext `json:"username_source,omitempty"` + DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"` + NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"` + ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"` + Identities service.UserIdentitySummarySet `json:"identities"` + AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"` + IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"` + EmailBound bool `json:"email_bound"` + LinuxDoBound bool `json:"linuxdo_bound"` + OIDCBound bool `json:"oidc_bound"` + WeChatBound bool `json:"wechat_bound"` +} + +type userProfileSourceContext struct { + Provider string `json:"provider,omitempty"` + Source string `json:"source,omitempty"` +} + // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { @@ -47,13 +80,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, dto.UserFromService(userData)) + response.Success(c, profileResp) } // ChangePassword handles changing 
user password @@ -101,6 +140,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { svcReq := service.UpdateProfileRequest{ Username: req.Username, + AvatarURL: req.AvatarURL, BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyThreshold: req.BalanceNotifyThreshold, } @@ -110,7 +150,155 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +type StartIdentityBindingRequest struct { + Provider string `json:"provider" binding:"required"` + RedirectTo string `json:"redirect_to"` +} + +type BindEmailIdentityRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code" binding:"required"` + Password string `json:"password" binding:"required"` +} + +type SendEmailBindingCodeRequest struct { + Email string `json:"email" binding:"required,email"` +} + +// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow. +// POST /api/v1/user/auth-identities/bind/start +func (h *UserHandler) StartIdentityBinding(c *gin.Context) { + if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req StartIdentityBindingRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{ + Provider: req.Provider, + RedirectTo: req.RedirectTo, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + +// BindEmailIdentity verifies and binds a local email identity for the current user. 
+// POST /api/v1/user/account-bindings/email +func (h *UserHandler) BindEmailIdentity(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req BindEmailIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + updatedUser, err := h.authService.BindEmailIdentity( + c.Request.Context(), + subject.UserID, + req.Email, + req.VerifyCode, + req.Password, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +// UnbindIdentity removes a third-party sign-in provider from the current user. +// DELETE /api/v1/user/account-bindings/:provider +func (h *UserHandler) UnbindIdentity(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult( + c.Request.Context(), + subject.UserID, + c.Param("provider"), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + if unbound && h.authService != nil { + if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil { + response.ErrorFrom(c, err) + return + } + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +// SendEmailBindingCode sends a verification code for the current user's email binding flow. 
+// POST /api/v1/user/account-bindings/email/send-code +func (h *UserHandler) SendEmailBindingCode(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req SendEmailBindingCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Verification code sent successfully"}) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -176,7 +364,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -212,7 +406,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -248,5 +448,117 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + 
} + + response.Success(c, profileResp) +} + +func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) { + identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user) + if err != nil { + return userProfileResponse{}, err + } + return userProfileResponseFromService(user, identities), nil +} + +func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse { + base := dto.UserFromService(user) + if base == nil { + return userProfileResponse{} + } + bindings := userProfileBindingMap(identities) + profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities) + return userProfileResponse{ + User: *base, + AvatarURL: user.AvatarURL, + AvatarSource: avatarSource, + UsernameSource: usernameSource, + DisplayNameSource: usernameSource, + NicknameSource: usernameSource, + ProfileSources: profileSources, + Identities: identities, + AuthBindings: bindings, + IdentityBindings: bindings, + EmailBound: identities.Email.Bound, + LinuxDoBound: identities.LinuxDo.Bound, + OIDCBound: identities.OIDC.Bound, + WeChatBound: identities.WeChat.Bound, + } +} + +func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary { + return map[string]service.UserIdentitySummary{ + "email": identities.Email, + "linuxdo": identities.LinuxDo, + "oidc": identities.OIDC, + "wechat": identities.WeChat, + } +} + +func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) ( + map[string]*userProfileSourceContext, + *userProfileSourceContext, + *userProfileSourceContext, +) { + if user == nil { + return nil, nil, nil + } + + thirdParty := thirdPartyIdentityProviders(identities) + var avatarSource *userProfileSourceContext + avatarValue := strings.TrimSpace(user.AvatarURL) + for _, summary := range thirdParty { + if avatarValue != "" && avatarValue == 
strings.TrimSpace(summary.AvatarURL) { + avatarSource = buildUserProfileSourceContext(summary.Provider) + break + } + } + + usernameValue := strings.TrimSpace(user.Username) + var usernameSource *userProfileSourceContext + for _, summary := range thirdParty { + if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) { + usernameSource = buildUserProfileSourceContext(summary.Provider) + break + } + } + + profileSources := map[string]*userProfileSourceContext{} + if avatarSource != nil { + profileSources["avatar"] = avatarSource + } + if usernameSource != nil { + profileSources["username"] = usernameSource + profileSources["display_name"] = usernameSource + profileSources["nickname"] = usernameSource + } + if len(profileSources) == 0 { + return nil, avatarSource, usernameSource + } + return profileSources, avatarSource, usernameSource +} + +func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { + out := make([]service.UserIdentitySummary, 0, 3) + for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { + if summary.Bound { + out = append(out, summary) + } + } + return out +} + +func buildUserProfileSourceContext(provider string) *userProfileSourceContext { + provider = strings.TrimSpace(provider) + if provider == "" { + return nil + } + return &userProfileSourceContext{ + Provider: provider, + Source: provider, + } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a655b81cfd93c6e746a55c57e8861c01fc5dccca --- /dev/null +++ b/backend/internal/handler/user_handler_test.go @@ -0,0 +1,783 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userHandlerRepoStub struct { + user *service.User + identities []service.UserAuthIdentityRecord + unbound []string +} + +func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } +func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error { + cloned := *user + s.user = &cloned + return nil +} +func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + if s.user == nil || s.user.AvatarURL == "" { + return nil, nil + } + return &service.UserAvatar{ + StorageProvider: s.user.AvatarSource, + URL: s.user.AvatarURL, + ContentType: s.user.AvatarMIME, + ByteSize: s.user.AvatarByteSize, + SHA256: s.user.AvatarSHA256, + }, nil +} +func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + s.user.AvatarURL = input.URL + s.user.AvatarSource = input.StorageProvider + s.user.AvatarMIME = input.ContentType + s.user.AvatarByteSize = input.ByteSize + s.user.AvatarSHA256 = input.SHA256 + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (s 
*userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error { + s.user.AvatarURL = "" + s.user.AvatarSource = "" + s.user.AvatarMIME = "" + s.user.AvatarByteSize = 0 + s.user.AvatarSHA256 = "" + return nil +} +func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} +func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error { + if s.user != nil { + s.user.LastActiveAt = &activeAt + } + return nil +} +func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s 
*userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) { + out := make([]service.UserAuthIdentityRecord, len(s.identities)) + copy(out, s.identities) + return out, nil +} +func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error { + s.unbound = append(s.unbound, provider) + filtered := s.identities[:0] + for _, identity := range s.identities { + if identity.ProviderType == provider { + continue + } + filtered = append(filtered, identity) + } + s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...) + return nil +} + +func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "handler-avatar@example.com", + Username: "handler-avatar", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.UpdateProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + AvatarURL string `json:"avatar_url"` + Username string `json:"username"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) + require.Equal(t, "handler-avatar", resp.Data.Username) +} + +func 
TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-123456", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + { + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "oidc-user-abc", + Metadata: map[string]any{ + "suggested_display_name": "OIDC Display", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Identities struct { + Email struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + } `json:"email"` + LinuxDo struct { + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"linuxdo"` + OIDC struct { + Bound bool `json:"bound"` + DisplayName string `json:"display_name"` + ProviderKey string `json:"provider_key"` + } `json:"oidc"` + WeChat struct { + Bound bool `json:"bound"` + CanBind bool `json:"can_bind"` + BindStartPath string `json:"bind_start_path"` + } `json:"wechat"` + } `json:"identities"` + } `json:"data"` + } + require.NoError(t, 
json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.Identities.Email.Bound) + require.Equal(t, 1, resp.Data.Identities.Email.BoundCount) + require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName) + require.True(t, resp.Data.Identities.LinuxDo.Bound) + require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount) + require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName) + require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey) + require.True(t, resp.Data.Identities.OIDC.Bound) + require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName) + require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey) + require.False(t, resp.Data.Identities.WeChat.Bound) + require.True(t, resp.Data.Identities.WeChat.CanBind) + require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start") +} + +func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "legacy-profile@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + 
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, false, resp.Data["oidc_bound"]) + require.Equal(t, false, resp.Data["wechat_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + require.Equal(t, "linuxdo", avatarSource["source"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + require.Equal(t, "linuxdo", linuxdoBinding["provider"]) + + identityBindings, ok := resp.Data["identity_bindings"].(map[string]any) + require.True(t, ok) + emailBinding, ok := identityBindings["email"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, emailBinding["bound"]) + require.Equal(t, "profile.authBindings.notes.emailManagedFromProfile", emailBinding["note_key"]) + + linuxdoCompatBinding, ok := identityBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, "profile.authBindings.notes.canUnbind", linuxdoCompatBinding["note_key"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) + require.Equal(t, "linuxdo", usernameSource["source"]) +} + +func 
TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 22, + Email: "edited-profile@example.com", + Username: "custom-name", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/custom.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-22", + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.NotContains(t, resp.Data, "avatar_source") + require.NotContains(t, resp.Data, "username_source") + require.NotContains(t, resp.Data, "profile_sources") +} + +type userHandlerEmailCacheStub struct { + data *service.VerificationCodeData +} + +type userHandlerRefreshTokenCacheStub struct { + revokedUserIDs []int64 +} + +func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *userHandlerRefreshTokenCacheStub) 
DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error { + s.revokedUserIDs = append(s.revokedUserIDs, userID) + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + return s.data, nil +} + +func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, 
nil +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain, + Username: "legacy-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + 
c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "provider", Value: "email"}} + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Email string `json:"email"` + EmailBound bool `json:"email_bound"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "new@example.com", resp.Data.Email) + require.True(t, resp.Data.EmailBound) +} + +func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []string{"linuxdo"}, repo.unbound) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + 
authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, false, linuxdoBinding["bound"]) +} + +func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 23, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-23", + }, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(5), repo.user.TokenVersion) +} + +func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 24, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + 
Status: service.StatusActive, + TokenVersion: 4, + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "identity@example.com", + }, + }, + } + refreshTokenCache := &userHandlerRefreshTokenCacheStub{} + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24}) + c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}} + + handler.UnbindIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Empty(t, repo.unbound) + require.Empty(t, refreshTokenCache.revokedUserIDs) + require.Equal(t, int64(4), repo.user.TokenVersion) +} + +func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Email: "current@example.com", + Username: "bound-user", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, user.SetPassword("current-password")) + + repo := &userHandlerRepoStub{user: user} + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := 
NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "PASSWORD_INCORRECT", resp.Reason) + require.Equal(t, "current password is incorrect", resp.Message) + require.Equal(t, "current@example.com", repo.user.Email) +} + +func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "identity@example.com", + Username: "identity-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + + body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.StartIdentityBinding(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Provider string 
`json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "wechat", resp.Data.Provider) + require.Equal(t, "GET", resp.Data.Method) + require.True(t, resp.Data.UseBrowserRedirect) + require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start") + require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user") + require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile") +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 4b54d41ad32c28cc79582eb458459031da55c875..6d175488977bf15e2b45beaea4c17d8a45ebf1f7 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -34,35 +34,39 @@ func ProvideAdminHandlers( apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, channelHandler *admin.ChannelHandler, + channelMonitorHandler *admin.ChannelMonitorHandler, + channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, paymentHandler *admin.PaymentHandler, ) *AdminHandlers { return &AdminHandlers{ - Dashboard: dashboardHandler, - User: userHandler, - Group: groupHandler, - Account: accountHandler, - Announcement: announcementHandler, - DataManagement: dataManagementHandler, - Backup: backupHandler, - OAuth: oauthHandler, - OpenAIOAuth: openaiOAuthHandler, - GeminiOAuth: geminiOAuthHandler, - AntigravityOAuth: antigravityOAuthHandler, - Proxy: proxyHandler, - Redeem: redeemHandler, - Promo: promoHandler, - Setting: settingHandler, - Ops: opsHandler, - System: systemHandler, - Subscription: subscriptionHandler, - Usage: usageHandler, - UserAttribute: userAttributeHandler, - ErrorPassthrough: errorPassthroughHandler, - TLSFingerprintProfile: tlsFingerprintProfileHandler, - APIKey: 
apiKeyHandler, - ScheduledTest: scheduledTestHandler, - Channel: channelHandler, - Payment: paymentHandler, + Dashboard: dashboardHandler, + User: userHandler, + Group: groupHandler, + Account: accountHandler, + Announcement: announcementHandler, + DataManagement: dataManagementHandler, + Backup: backupHandler, + OAuth: oauthHandler, + OpenAIOAuth: openaiOAuthHandler, + GeminiOAuth: geminiOAuthHandler, + AntigravityOAuth: antigravityOAuthHandler, + Proxy: proxyHandler, + Redeem: redeemHandler, + Promo: promoHandler, + Setting: settingHandler, + Ops: opsHandler, + System: systemHandler, + Subscription: subscriptionHandler, + Usage: usageHandler, + UserAttribute: userAttributeHandler, + ErrorPassthrough: errorPassthroughHandler, + TLSFingerprintProfile: tlsFingerprintProfileHandler, + APIKey: apiKeyHandler, + ScheduledTest: scheduledTestHandler, + Channel: channelHandler, + ChannelMonitor: channelMonitorHandler, + ChannelMonitorTemplate: channelMonitorTemplateHandler, + Payment: paymentHandler, } } @@ -85,6 +89,7 @@ func ProvideHandlers( redeemHandler *RedeemHandler, subscriptionHandler *SubscriptionHandler, announcementHandler *AnnouncementHandler, + channelMonitorUserHandler *ChannelMonitorUserHandler, adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, @@ -92,24 +97,27 @@ func ProvideHandlers( totpHandler *TotpHandler, paymentHandler *PaymentHandler, paymentWebhookHandler *PaymentWebhookHandler, + availableChannelHandler *AvailableChannelHandler, _ *service.IdempotencyCoordinator, _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ - Auth: authHandler, - User: userHandler, - APIKey: apiKeyHandler, - Usage: usageHandler, - Redeem: redeemHandler, - Subscription: subscriptionHandler, - Announcement: announcementHandler, - Admin: adminHandlers, - Gateway: gatewayHandler, - OpenAIGateway: openaiGatewayHandler, - Setting: settingHandler, - Totp: totpHandler, - Payment: paymentHandler, - 
PaymentWebhook: paymentWebhookHandler, + Auth: authHandler, + User: userHandler, + APIKey: apiKeyHandler, + Usage: usageHandler, + Redeem: redeemHandler, + Subscription: subscriptionHandler, + Announcement: announcementHandler, + ChannelMonitor: channelMonitorUserHandler, + Admin: adminHandlers, + Gateway: gatewayHandler, + OpenAIGateway: openaiGatewayHandler, + Setting: settingHandler, + Totp: totpHandler, + Payment: paymentHandler, + PaymentWebhook: paymentWebhookHandler, + AvailableChannel: availableChannelHandler, } } @@ -123,12 +131,14 @@ var ProviderSet = wire.NewSet( NewRedeemHandler, NewSubscriptionHandler, NewAnnouncementHandler, + NewChannelMonitorUserHandler, NewGatewayHandler, NewOpenAIGatewayHandler, NewTotpHandler, ProvideSettingHandler, NewPaymentHandler, NewPaymentWebhookHandler, + NewAvailableChannelHandler, // Admin handlers admin.NewDashboardHandler, @@ -156,6 +166,8 @@ var ProviderSet = wire.NewSet( admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, admin.NewChannelHandler, + admin.NewChannelMonitorHandler, + admin.NewChannelMonitorRequestTemplateHandler, admin.NewPaymentHandler, // AdminHandlers and Handlers constructors diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index e39e957f4d29e2efd1dad9ce5e709f76764f0454..0581469d6549c6394cf5cf707793566d2ab52e5a 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -10,12 +10,20 @@ import ( "strings" ) +// AES256KeySize is the required key length (in bytes) for AES-256-GCM. +const AES256KeySize = 32 + // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function is kept only for seeding legacy ciphertext in tests and for +// the transitional Decrypt fallback. 
Scheduled for removal after all live +// deployments complete migration by re-saving their configs. func Encrypt(plaintext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } block, err := aes.NewCipher(key) @@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function remains only as a read-path fallback for pre-migration +// ciphertext records. Scheduled for removal once all deployments re-save +// their provider configs through the admin UI. func Decrypt(ciphertext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } parts := strings.SplitN(ciphertext, ":", 3) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index f0353173847966f3659ea8ff8c642d979d8f4f9b..41fd2c50cf52a59f8e04ffad6bb33d2e1936b4a4 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -45,11 +45,31 @@ type DefaultLoadBalancer struct { counter atomic.Uint64 } +type contextKey string + +const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id" + // NewDefaultLoadBalancer creates a new load balancer. 
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer { return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey} } +func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context { + appID = strings.TrimSpace(appID) + if appID == "" { + return ctx + } + return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID) +} + +func wxpayJSAPIAppIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string) + return strings.TrimSpace(appID) +} + // instanceCandidate pairs an instance with its pre-fetched daily usage. type instanceCandidate struct { inst *dbent.PaymentProviderInstance @@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( } var matched []*dbent.PaymentProviderInstance + expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx) for _, inst := range instances { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". 
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( matched = append(matched, inst) } } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { + if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay { + config, cfgErr := lb.decryptConfig(inst.Config) + if cfgErr != nil { + slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr) + continue + } + if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID { + continue + } + } matched = append(matched, inst) } } @@ -231,6 +262,11 @@ func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType P if cl, ok := limits[lookupKey]; ok { return cl } + if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" { + if cl, ok := limits[aliasKey]; ok { + return cl + } + } return ChannelLimits{} } @@ -261,6 +297,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns if err != nil { return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err) } + if config == nil { + config = map[string]string{} + } if selected.PaymentMode != "" { config["paymentMode"] = selected.PaymentMode @@ -275,16 +314,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns }, nil } -func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) { - plaintext, err := Decrypt(encrypted, lb.encryptionKey) - if err != nil { - return nil, err +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext. +// Unreadable values (legacy ciphertext without a valid key, or malformed data) +// are treated as empty so the service keeps running while the admin re-enters +// the config via the UI. 
+// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a +// transitional compatibility shim for pre-plaintext records. Remove it (and +// the encryptionKey field + the Decrypt import) after a few releases once all +// live deployments have re-saved their provider configs through the UI. +func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { + return nil, nil } var config map[string]string - if err := json.Unmarshal([]byte(plaintext), &config); err != nil { - return nil, fmt.Errorf("unmarshal config: %w", err) + if err := json.Unmarshal([]byte(stored), &config); err == nil { + return config, nil + } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. + if len(lb.encryptionKey) == AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal + if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &config); err == nil { + return config, nil + } + } } - return config, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } // GetInstanceDailyAmount returns the total completed order amount for an instance today. 
@@ -321,14 +380,45 @@ func InstanceSupportsType(supportedTypes string, target PaymentType) bool { if supportedTypes == "" { return true } + normalizedTarget := normalizeVisibleMethodSupportType(target) for _, t := range strings.Split(supportedTypes, ",") { - if strings.TrimSpace(t) == target { + supported := strings.TrimSpace(t) + if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget { return true } } return false } +func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType { + switch strings.TrimSpace(paymentType) { + case TypeAlipay, TypeAlipayDirect: + return TypeAlipay + case TypeWxpay, TypeWxpayDirect: + return TypeWxpay + default: + return strings.TrimSpace(paymentType) + } +} + +func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType { + switch normalizeVisibleMethodSupportType(paymentType) { + case TypeAlipay: + return TypeAlipayDirect + case TypeWxpay: + return TypeWxpayDirect + default: + return "" + } +} + +func resolveWxpayJSAPIAppID(config map[string]string) string { + if appID := strings.TrimSpace(config["mpAppId"]); appID != "" { + return appID + } + return strings.TrimSpace(config["appId"]) +} + // GetInstanceConfig decrypts and returns the configuration for a provider instance by ID. 
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID) diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 04b3c25b4745b798c3286c5b127634d04651b2e8..ed08a7dd49092bcf0d3938077d51e7bbb8c50f9d 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -68,10 +68,16 @@ func TestInstanceSupportsType(t *testing.T) { expected: true, }, { - name: "partial match should not succeed", + name: "legacy alipay direct supports canonical visible method", supportedTypes: "alipay_direct", target: "alipay", - expected: false, + expected: true, + }, + { + name: "legacy wxpay direct supports canonical visible method", + supportedTypes: "wxpay_direct", + target: "wxpay", + expected: true, }, { name: "empty supported types means all supported", @@ -92,6 +98,22 @@ func TestInstanceSupportsType(t *testing.T) { } } +func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) { + t.Parallel() + + inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66})) + got := getInstanceChannelLimits(inst, TypeAlipay) + if got.SingleMax != 66 { + t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got) + } + + wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8})) + wxGot := getInstanceChannelLimits(wxInst, TypeWxpay) + if wxGot.SingleMin != 8 { + t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot) + } +} + // --------------------------------------------------------------------------- // Helper to build test PaymentProviderInstance values // --------------------------------------------------------------------------- @@ -452,6 +474,103 @@ func TestStartOfDay(t *testing.T) { } } +func TestDecryptConfig_PlaintextAndLegacyCompat(t 
*testing.T) { + t.Parallel() + + key := make([]byte, AES256KeySize) + for i := range key { + key[i] = byte(i + 1) + } + wrongKey := make([]byte, AES256KeySize) + for i := range wrongKey { + wrongKey[i] = byte(0xFF - i) + } + + plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}` + + legacyEncrypted, err := Encrypt(plaintextJSON, key) + if err != nil { + t.Fatalf("seed Encrypt: %v", err) + } + + tests := []struct { + name string + stored string + key []byte + want map[string]string + }{ + { + name: "empty stored returns nil map", + stored: "", + key: key, + want: nil, + }, + { + name: "plaintext JSON parses directly", + stored: plaintextJSON, + key: nil, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "plaintext JSON works even with key present", + stored: plaintextJSON, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with correct key decrypts", + stored: legacyEncrypted, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with no key treated as empty", + stored: legacyEncrypted, + key: nil, + want: nil, + }, + { + name: "legacy ciphertext with wrong key treated as empty", + stored: legacyEncrypted, + key: wrongKey, + want: nil, + }, + { + name: "garbage data treated as empty", + stored: "not-json-and-not-ciphertext", + key: key, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + lb := NewDefaultLoadBalancer(nil, tt.key) + got, err := lb.decryptConfig(tt.stored) + if err != nil { + t.Fatalf("decryptConfig unexpected error: %v", err) + } + if !stringMapEqual(got, tt.want) { + t.Fatalf("decryptConfig = %v, want %v", got, tt.want) + } + }) + } +} + +// stringMapEqual compares two map[string]string values; nil and empty are equal. 
+func stringMapEqual(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if bv, ok := b[k]; !ok || bv != v { + return false + } + } + return true +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index af8a90c681a2897dd7b52f2c58e2fca9b4f7f747..1234b56819e8984b69c89cdae6b3217974670d70 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -15,8 +15,9 @@ import ( // Alipay product codes. const ( - alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" - alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePreCreate = "FACE_TO_FACE_PAYMENT" + alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" ) // Alipay response constants. @@ -26,6 +27,18 @@ const ( alipayRefundSuffix = "-refund" ) +var ( + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + return client.TradeWapPay(param) + } + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + return client.TradePreCreate(ctx, param) + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + return client.TradePagePay(param) + } +) + // Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK. type Alipay struct { instanceID string @@ -79,8 +92,24 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } -// CreatePayment creates an Alipay payment page URL. 
-func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { +func (a *Alipay) MerchantIdentityMetadata() map[string]string { + if a == nil { + return nil + } + appID := strings.TrimSpace(a.config["appId"]) + if appID == "" { + return nil + } + return map[string]string{"app_id": appID} +} + +// CreatePayment creates an Alipay payment using the following routing: +// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay. +// - Desktop: prefer alipay.trade.precreate to get a scan payload directly. +// - Desktop fallback: if precreate is unavailable for the merchant, fall back +// to alipay.trade.page.pay and expose both pay_url and qr_code so the +// frontend can render a QR while still allowing direct page open. +func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { return nil, err @@ -96,31 +125,73 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque } if req.IsMobile { - return a.createTrade(client, req, notifyURL, returnURL, true) + return a.createWapTrade(client, req, notifyURL, returnURL) } - return a.createTrade(client, req, notifyURL, returnURL, false) + return a.createDesktopTrade(ctx, client, req, notifyURL, returnURL) } -func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) { - if isMobile { - param := alipay.TradeWapPay{} - param.OutTradeNo = req.OrderID - param.TotalAmount = req.Amount - param.Subject = req.Subject - param.ProductCode = alipayProductCodeWapPay - param.NotifyURL = notifyURL - param.ReturnURL = returnURL - - payURL, err := client.TradeWapPay(param) - if err != nil { - return nil, fmt.Errorf("alipay TradeWapPay: %w", err) - } - return &payment.CreatePaymentResponse{ - TradeNo: req.OrderID, - PayURL: 
payURL.String(), - }, nil +func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradeWapPay{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodeWapPay + param.NotifyURL = notifyURL + param.ReturnURL = returnURL + + payURL, err := alipayTradeWapPay(client, param) + if err != nil { + return nil, fmt.Errorf("alipay TradeWapPay: %w", err) + } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + PayURL: payURL.String(), + }, nil +} + +func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL) + if precreateErr == nil { + return resp, nil + } + + resp, pagePayErr := a.createPagePayTrade(client, req, notifyURL, returnURL) + if pagePayErr == nil { + return resp, nil + } + + return nil, fmt.Errorf("alipay desktop payment failed: precreate=%v; pagepay=%w", precreateErr, pagePayErr) +} + +func (a *Alipay) createPrecreateTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradePreCreate{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodePreCreate + param.NotifyURL = notifyURL + + rsp, err := alipayTradePreCreate(ctx, client, param) + if err != nil { + return nil, fmt.Errorf("alipay TradePreCreate: %w", err) + } + if rsp == nil { + return nil, fmt.Errorf("alipay TradePreCreate: empty response") + } + if rsp.IsFailure() { + return nil, fmt.Errorf("alipay TradePreCreate failed: %s", rsp.Error.Error()) + } + if strings.TrimSpace(rsp.QRCode) == "" { + return nil, 
fmt.Errorf("alipay TradePreCreate: empty qr_code") } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + QRCode: rsp.QRCode, + }, nil +} + +func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { param := alipay.TradePagePay{} param.OutTradeNo = req.OrderID param.TotalAmount = req.Amount @@ -129,7 +200,7 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq param.NotifyURL = notifyURL param.ReturnURL = returnURL - payURL, err := client.TradePagePay(param) + payURL, err := alipayTradePagePay(client, param) if err != nil { return nil, fmt.Errorf("alipay TradePagePay: %w", err) } @@ -168,14 +239,23 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query amount, err := strconv.ParseFloat(result.TotalAmount, 64) if err != nil { - return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err) + amount, err = parseAlipayAmount( + result.TotalAmount, + result.ReceiptAmount, + result.BuyerPayAmount, + result.InvoiceAmount, + ) + if err != nil { + return nil, fmt.Errorf("alipay parse amount: %w", err) + } } return &payment.QueryOrderResponse{ - TradeNo: result.TradeNo, - Status: status, - Amount: amount, - PaidAt: result.SendPayDate, + TradeNo: result.TradeNo, + Status: status, + Amount: amount, + PaidAt: result.SendPayDate, + Metadata: a.MerchantIdentityMetadata(), }, nil } @@ -203,15 +283,31 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s amount, err := strconv.ParseFloat(notification.TotalAmount, 64) if err != nil { - return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err) + amount, err = parseAlipayAmount( + notification.TotalAmount, + notification.ReceiptAmount, + notification.BuyerPayAmount, + ) + if err != nil { + return nil, fmt.Errorf("alipay parse notification amount: %w", err) + } + } + + 
metadata := a.MerchantIdentityMetadata() + if appID := strings.TrimSpace(notification.AppId); appID != "" { + if metadata == nil { + metadata = map[string]string{} + } + metadata["app_id"] = appID } return &payment.PaymentNotification{ - TradeNo: notification.TradeNo, - OrderID: notification.OutTradeNo, - Amount: amount, - Status: status, - RawData: rawBody, + TradeNo: notification.TradeNo, + OrderID: notification.OutTradeNo, + Amount: amount, + Status: status, + RawData: rawBody, + Metadata: metadata, }, nil } @@ -272,8 +368,23 @@ func isTradeNotExist(err error) bool { return strings.Contains(err.Error(), alipayErrTradeNotExist) } +func parseAlipayAmount(values ...string) (float64, error) { + for _, raw := range values { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + amount, err := strconv.ParseFloat(raw, 64) + if err == nil { + return amount, nil + } + } + return 0, fmt.Errorf("no valid amount field") +} + // Ensure interface compliance. var ( - _ payment.Provider = (*Alipay)(nil) - _ payment.CancelableProvider = (*Alipay)(nil) + _ payment.Provider = (*Alipay)(nil) + _ payment.CancelableProvider = (*Alipay)(nil) + _ payment.MerchantIdentityProvider = (*Alipay)(nil) ) diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 7b0ce0d8c86223743bfe8a92f0094033914d6cb6..fdc8eec1ac546cde5a214151696b3bc6b871413d 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -3,9 +3,14 @@ package provider import ( + "context" "errors" + "net/url" "strings" "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/smartwalle/alipay/v3" ) func TestIsTradeNotExist(t *testing.T) { @@ -130,3 +135,173 @@ func TestNewAlipay(t *testing.T) { }) } } + +func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { + origPreCreate := alipayTradePreCreate + origPagePay := alipayTradePagePay + origWapPay := alipayTradeWapPay + 
t.Cleanup(func() { + alipayTradePreCreate = origPreCreate + alipayTradePagePay = origPagePay + alipayTradeWapPay = origWapPay + }) + + preCreateCalls := 0 + pagePayCalls := 0 + wapPayCalls := 0 + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + preCreateCalls++ + return nil, errors.New("merchant does not have FACE_TO_FACE_PAYMENT") + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ + if param.OutTradeNo != "sub2_100" { + t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100") + } + if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" { + t.Fatalf("notify_url = %q", param.NotifyURL) + } + return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") + } + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + wapPayCalls++ + return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay") + } + + provider := &Alipay{} + resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_100", + Amount: "88.00", + Subject: "Balance recharge", + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if preCreateCalls != 1 { + t.Fatalf("precreate calls = %d, want 1", preCreateCalls) + } + if pagePayCalls != 1 { + t.Fatalf("page pay calls = %d, want 1", pagePayCalls) + } + if wapPayCalls != 0 { + t.Fatalf("wap pay calls = %d, want 0", wapPayCalls) + } + if resp.PayURL == "" { + t.Fatal("expected pay_url for desktop page pay") + } + if resp.QRCode != resp.PayURL { + t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL) + } +} + +func TestCreateTradeUsesWapPayForMobile(t *testing.T) { + origWapPay := alipayTradeWapPay + 
t.Cleanup(func() { + alipayTradeWapPay = origWapPay + }) + + wapPayCalls := 0 + alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { + wapPayCalls++ + if param.ReturnURL != "https://merchant.example.com/payment/result" { + t.Fatalf("return_url = %q", param.ReturnURL) + } + return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay") + } + + provider := &Alipay{} + resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_101", + Amount: "18.00", + Subject: "Balance recharge", + IsMobile: true, + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wapPayCalls != 1 { + t.Fatalf("wap pay calls = %d, want 1", wapPayCalls) + } + if resp.PayURL == "" { + t.Fatal("expected pay_url for mobile wap pay") + } +} + +func TestCreateTradeUsesPrecreateForDesktopWhenAvailable(t *testing.T) { + origPreCreate := alipayTradePreCreate + origPagePay := alipayTradePagePay + t.Cleanup(func() { + alipayTradePreCreate = origPreCreate + alipayTradePagePay = origPagePay + }) + + preCreateCalls := 0 + pagePayCalls := 0 + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + preCreateCalls++ + if param.ProductCode != alipayProductCodePreCreate { + t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePreCreate) + } + return &alipay.TradePreCreateRsp{ + Error: alipay.Error{Code: alipay.CodeSuccess}, + QRCode: "https://qr.alipay.example.com/precreate-token", + }, nil + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ + return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") + } + + provider := &Alipay{} + resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, 
payment.CreatePaymentRequest{ + OrderID: "sub2_102", + Amount: "66.00", + Subject: "Balance recharge", + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if preCreateCalls != 1 { + t.Fatalf("precreate calls = %d, want 1", preCreateCalls) + } + if pagePayCalls != 0 { + t.Fatalf("page pay calls = %d, want 0", pagePayCalls) + } + if resp.QRCode != "https://qr.alipay.example.com/precreate-token" { + t.Fatalf("qr_code = %q", resp.QRCode) + } + if resp.PayURL != "" { + t.Fatalf("pay_url = %q, want empty for precreate", resp.PayURL) + } +} + +func TestAlipayMerchantIdentityMetadata(t *testing.T) { + t.Parallel() + + provider := &Alipay{ + config: map[string]string{ + "appId": "2021001234567890", + }, + } + + metadata := provider.MerchantIdentityMetadata() + if metadata["app_id"] != "2021001234567890" { + t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890") + } +} + +func TestParseAlipayAmount(t *testing.T) { + t.Parallel() + + amount, err := parseAlipayAmount("", "88.00", "77.00") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if amount != 88 { + t.Fatalf("amount = %v, want 88", amount) + } + + if _, err := parseAlipayAmount("", "not-a-number"); err == nil { + t.Fatal("expected error when no valid amount field exists") + } +} diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index e33a567d0398b240a26144f3f1dd44836748d23d..37bd38b27fc39bfeb055c739b49b399fb32ed9c0 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -59,6 +59,17 @@ func (e *EasyPay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay} } +func (e *EasyPay) MerchantIdentityMetadata() map[string]string { + if e == nil { + return nil + } + pid := strings.TrimSpace(e.config["pid"]) + 
if pid == "" { + return nil + } + return map[string]string{"pid": pid} +} + func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { // Payment mode determined by instance config, not payment type. // "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php). @@ -178,7 +189,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer status = payment.ProviderStatusPaid } amount, _ := strconv.ParseFloat(resp.Money, 64) - return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil + return &payment.QueryOrderResponse{ + TradeNo: tradeNo, + Status: status, + Amount: amount, + Metadata: e.MerchantIdentityMetadata(), + }, nil } func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) { @@ -203,9 +219,17 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st status = payment.ProviderStatusSuccess } amount, _ := strconv.ParseFloat(params["money"], 64) + + metadata := e.MerchantIdentityMetadata() + if pid := strings.TrimSpace(params["pid"]); pid != "" { + if metadata == nil { + metadata = map[string]string{} + } + metadata["pid"] = pid + } return &payment.PaymentNotification{ TradeNo: params["trade_no"], OrderID: params["out_trade_no"], - Amount: amount, Status: status, RawData: rawBody, + Amount: amount, Status: status, RawData: rawBody, Metadata: metadata, }, nil } diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go index 146a6fa1afd7aea5649cd9d110edcce86dbbdbb0..8328d294e88bae3bdcfb08ef9c267705953bdba0 100644 --- a/backend/internal/payment/provider/easypay_sign_test.go +++ b/backend/internal/payment/provider/easypay_sign_test.go @@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) { t.Fatal("easyPayVerifySign should return false for 
an incorrect sign value") } } + +func TestEasyPayMerchantIdentityMetadata(t *testing.T) { + t.Parallel() + + provider := &EasyPay{ + config: map[string]string{ + "pid": "1001", + }, + } + + metadata := provider.MerchantIdentityMetadata() + if metadata["pid"] != "1001" { + t.Fatalf("pid = %q, want %q", metadata["pid"], "1001") + } +} diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 0b41c4fb6f348dd1bfde6198b47f367a2ea880a9..e6291dd31dff1bdd73ca7605ed802c65c6e21472 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -3,22 +3,24 @@ package provider import ( "bytes" "context" - "crypto/rsa" "fmt" "io" - "log/slog" "net/http" + "net/url" + "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/wechatpay-apiv3/wechatpay-go/core" "github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers" "github.com/wechatpay-apiv3/wechatpay-go/core/notify" "github.com/wechatpay-apiv3/wechatpay-go/core/option" "github.com/wechatpay-apiv3/wechatpay-go/services/payments" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" "github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic" "github.com/wechatpay-apiv3/wechatpay-go/utils" @@ -26,8 +28,23 @@ import ( // WeChat Pay constants. const ( - wxpayCurrency = "CNY" - wxpayH5Type = "Wap" + wxpayCurrency = "CNY" + wxpayH5Type = "Wap" + wxpayResultPath = "/payment/result" +) + +const ( + wxpayMetadataAppID = "appid" + wxpayMetadataMerchantID = "mchid" + wxpayMetadataCurrency = "currency" + wxpayMetadataTradeState = "trade_state" +) + +// WeChat Pay create-payment modes. +const ( + wxpayModeNative = "native" + wxpayModeH5 = "h5" + wxpayModeJSAPI = "jsapi" ) // WeChat Pay trade states. 
@@ -43,9 +60,16 @@ const ( wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS" ) -// WeChat Pay error codes. -const ( - wxpayErrNoAuth = "NO_AUTH" +var ( + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + return svc.Prepay(ctx, req) + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + return svc.Prepay(ctx, req) + } + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + return svc.PrepayWithRequestPayment(ctx, req) + } ) type Wxpay struct { @@ -56,15 +80,35 @@ type Wxpay struct { notifyHandler *notify.Handler } +const wxpayAPIv3KeyLength = 32 + func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) { - required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"} + // All fields are required. Platform-certificate mode is intentionally unsupported — + // WeChat has been migrating all merchants to the pubkey verifier since 2024-10, + // and newly-provisioned merchants cannot download platform certificates at all. + required := []string{"appId", "mchId", "privateKey", "apiV3Key", "certSerial", "publicKey", "publicKeyId"} for _, k := range required { if config[k] == "" { - return nil, fmt.Errorf("wxpay config missing required key: %s", k) + return nil, infraerrors.BadRequest("WXPAY_CONFIG_MISSING_KEY", "missing_required_key"). + WithMetadata(map[string]string{"key": k}) } } - if len(config["apiV3Key"]) != 32 { - return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"])) + if len(config["apiV3Key"]) != wxpayAPIv3KeyLength { + return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY_LENGTH", "invalid_key_length"). 
+ WithMetadata(map[string]string{ + "key": "apiV3Key", + "expected": strconv.Itoa(wxpayAPIv3KeyLength), + "actual": strconv.Itoa(len(config["apiV3Key"])), + }) + } + // Parse PEMs eagerly so malformed keys surface at save time, not at order creation. + if _, err := utils.LoadPrivateKey(formatPEM(config["privateKey"], "PRIVATE KEY")); err != nil { + return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key"). + WithMetadata(map[string]string{"key": "privateKey"}) + } + if _, err := utils.LoadPublicKey(formatPEM(config["publicKey"], "PUBLIC KEY")); err != nil { + return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key"). + WithMetadata(map[string]string{"key": "publicKey"}) } return &Wxpay{instanceID: instanceID, config: config}, nil } @@ -75,6 +119,16 @@ func (w *Wxpay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeWxpay} } +// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a +// given provider config. A dedicated MP AppID takes precedence over the base +// merchant AppID. +func ResolveWxpayJSAPIAppID(config map[string]string) string { + if appID := strings.TrimSpace(config["mpAppId"]); appID != "" { + return appID + } + return strings.TrimSpace(config["appId"]) +} + func formatPEM(key, keyType string) string { key = strings.TrimSpace(key) if strings.HasPrefix(key, "-----BEGIN") { @@ -89,14 +143,19 @@ func (w *Wxpay) ensureClient() (*core.Client, error) { if w.coreClient != nil { return w.coreClient, nil } - privateKey, publicKey, err := w.loadKeyPair() + privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY")) if err != nil { - return nil, err + return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key"). 
+ WithMetadata(map[string]string{"key": "privateKey"}) + } + publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY")) + if err != nil { + return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key"). + WithMetadata(map[string]string{"key": "publicKey"}) } - certSerial := w.config["certSerial"] verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey) client, err := core.NewClient(context.Background(), - option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey), + option.WithMerchantCredential(w.config["mchId"], w.config["certSerial"], privateKey), option.WithVerifier(verifier)) if err != nil { return nil, fmt.Errorf("wxpay init client: %w", err) @@ -110,18 +169,6 @@ func (w *Wxpay) ensureClient() (*core.Client, error) { return w.coreClient, nil } -func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) { - privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY")) - if err != nil { - return nil, nil, fmt.Errorf("wxpay load private key: %w", err) - } - publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY")) - if err != nil { - return nil, nil, fmt.Errorf("wxpay load public key: %w", err) - } - return privateKey, publicKey, nil -} - func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := w.ensureClient() if err != nil { @@ -139,30 +186,61 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ if err != nil { return nil, fmt.Errorf("wxpay create payment: %w", err) } - if req.IsMobile && req.ClientIP != "" { - resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true) - if err == nil { - return resp, nil - } - if !strings.Contains(err.Error(), wxpayErrNoAuth) { - return nil, err - } - slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID) + + 
mode, err := resolveWxpayCreateMode(req) + if err != nil { + return nil, err + } + switch mode { + case wxpayModeJSAPI: + return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen) + case wxpayModeH5: + return w.prepayH5(ctx, client, req, notifyURL, totalFen) + case wxpayModeNative: + return w.prepayNative(ctx, client, req, notifyURL, totalFen) + default: + return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode) } - return w.createOrder(ctx, client, req, notifyURL, totalFen, false) } -func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) { - if useH5 { - return w.prepayH5(ctx, c, req, notifyURL, totalFen) +func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { + svc := jsapi.JsapiApiService{Client: c} + cur := wxpayCurrency + appID := ResolveWxpayJSAPIAppID(w.config) + prepayReq := jsapi.PrepayRequest{ + Appid: core.String(appID), + Mchid: core.String(w.config["mchId"]), + Description: core.String(req.Subject), + OutTradeNo: core.String(req.OrderID), + NotifyUrl: core.String(notifyURL), + Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur}, + Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))}, + } + if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" { + prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)} } - return w.prepayNative(ctx, c, req, notifyURL, totalFen) + resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq) + if err != nil { + return nil, fmt.Errorf("wxpay jsapi prepay: %w", err) + } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + ResultType: payment.CreatePaymentResultJSAPIReady, + JSAPI: &payment.WechatJSAPIPayload{ + AppID: wxSV(resp.Appid), + TimeStamp: wxSV(resp.TimeStamp), + 
NonceStr: wxSV(resp.NonceStr), + Package: wxSV(resp.Package), + SignType: wxSV(resp.SignType), + PaySign: wxSV(resp.PaySign), + }, + }, nil } func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { svc := native.NativeApiService{Client: c} cur := wxpayCurrency - resp, _, err := svc.Prepay(ctx, native.PrepayRequest{ + resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]), Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID), NotifyUrl: core.String(notifyURL), @@ -181,13 +259,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) { svc := h5.H5ApiService{Client: c} cur := wxpayCurrency - tp := wxpayH5Type - resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{ + resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]), Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID), NotifyUrl: core.String(notifyURL), Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur}, - SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}}, + SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)}, }) if err != nil { return nil, fmt.Errorf("wxpay h5 prepay: %w", err) @@ -196,9 +273,77 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create if resp.H5Url != nil { h5URL = *resp.H5Url } + h5URL, err = appendWxpayRedirectURL(h5URL, req) + if err != nil { + return nil, err + } return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: 
h5URL}, nil } +func buildWxpayH5Info(config map[string]string) *h5.H5Info { + tp := wxpayH5Type + info := &h5.H5Info{Type: &tp} + if appName := strings.TrimSpace(config["h5AppName"]); appName != "" { + info.AppName = core.String(appName) + } + if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" { + info.AppUrl = core.String(appURL) + } + return info +} + +func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) { + if strings.TrimSpace(req.OpenID) != "" { + return wxpayModeJSAPI, nil + } + if req.IsMobile { + if strings.TrimSpace(req.ClientIP) == "" { + return "", fmt.Errorf("wxpay H5 payment requires client IP") + } + return wxpayModeH5, nil + } + return wxpayModeNative, nil +} + +func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) { + h5URL = strings.TrimSpace(h5URL) + returnURL := strings.TrimSpace(req.ReturnURL) + if h5URL == "" || returnURL == "" { + return h5URL, nil + } + + redirectURL, err := buildWxpayResultURL(returnURL, req) + if err != nil { + return "", err + } + + sep := "&" + if !strings.Contains(h5URL, "?") { + sep = "?" 
+ } + return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil +} + +func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) { + u, err := url.Parse(returnURL) + if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") { + return "", fmt.Errorf("return URL must be an absolute http(s) URL") + } + + values := u.Query() + values.Set("out_trade_no", strings.TrimSpace(req.OrderID)) + if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" { + values.Set("payment_type", paymentType) + } + if strings.TrimSpace(u.Path) == "" { + u.Path = wxpayResultPath + } + u.RawPath = "" + u.RawQuery = values.Encode() + u.Fragment = "" + return u.String(), nil +} + func wxSV(s *string) string { if s == nil { return "" @@ -219,6 +364,32 @@ func mapWxState(s string) string { } } +func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string { + if tx == nil { + return nil + } + + metadata := map[string]string{} + if appID := wxSV(tx.Appid); appID != "" { + metadata[wxpayMetadataAppID] = appID + } + if merchantID := wxSV(tx.Mchid); merchantID != "" { + metadata[wxpayMetadataMerchantID] = merchantID + } + if tradeState := wxSV(tx.TradeState); tradeState != "" { + metadata[wxpayMetadataTradeState] = tradeState + } + if tx.Amount != nil { + if currency := wxSV(tx.Amount.Currency); currency != "" { + metadata[wxpayMetadataCurrency] = currency + } + } + if len(metadata) == 0 { + return nil + } + return metadata +} + func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { c, err := w.ensureClient() if err != nil { @@ -243,7 +414,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO if tx.SuccessTime != nil { pa = *tx.SuccessTime } - return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil + return &payment.QueryOrderResponse{ + TradeNo: 
id, + Status: mapWxState(wxSV(tx.TradeState)), + Amount: amt, + PaidAt: pa, + Metadata: buildWxpayTransactionMetadata(tx), + }, nil } func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { @@ -275,7 +452,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers } return &payment.PaymentNotification{ TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo), - Amount: amt, Status: st, RawData: rawBody, + Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx), }, nil } diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go index b8b99537457f18701efdec12fc7cf7b06626b328..e8ac5e547be20c4286aec693cf11e4cfc64e6ccf 100644 --- a/backend/internal/payment/provider/wxpay_test.go +++ b/backend/internal/payment/provider/wxpay_test.go @@ -3,12 +3,44 @@ package provider import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "net/url" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/wechatpay-apiv3/wechatpay-go/core" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" + "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" ) +// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings. +// The wechatpay-go SDK expects PKCS8 private keys and PKIX public keys. 
+func generateTestKeyPair(t *testing.T) (privPEM, pubPEM string) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate rsa key: %v", err) + } + privDER, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatalf("marshal pkcs8: %v", err) + } + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + t.Fatalf("marshal pkix: %v", err) + } + return string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})) +} + func TestMapWxState(t *testing.T) { t.Parallel() @@ -96,6 +128,33 @@ func TestWxSV(t *testing.T) { } } +func TestBuildWxpayTransactionMetadata(t *testing.T) { + t.Parallel() + + tx := &payments.Transaction{ + Appid: strPtr("wx-app-id"), + Mchid: strPtr("mch-id"), + TradeState: strPtr(wxpayTradeStateSuccess), + Amount: &payments.TransactionAmount{ + Currency: strPtr(wxpayCurrency), + }, + } + + metadata := buildWxpayTransactionMetadata(tx) + if metadata[wxpayMetadataAppID] != "wx-app-id" { + t.Fatalf("appid = %q", metadata[wxpayMetadataAppID]) + } + if metadata[wxpayMetadataMerchantID] != "mch-id" { + t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID]) + } + if metadata[wxpayMetadataCurrency] != wxpayCurrency { + t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency]) + } + if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess { + t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState]) + } +} + func strPtr(s string) *string { return &s } @@ -149,13 +208,14 @@ func TestFormatPEM(t *testing.T) { func TestNewWxpay(t *testing.T) { t.Parallel() + privPEM, pubPEM := generateTestKeyPair(t) validConfig := map[string]string{ "appId": "wx1234567890", "mchId": "1234567890", - "privateKey": "fake-private-key", + "privateKey": privPEM, "apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes - "publicKey": "fake-public-key", - "publicKeyId": "key-id-001", + 
"publicKey": pubPEM, + "publicKeyId": "PUB_KEY_ID_TEST", "certSerial": "SERIAL001", } @@ -206,6 +266,12 @@ func TestNewWxpay(t *testing.T) { wantErr: true, errSubstr: "apiV3Key", }, + { + name: "missing certSerial", + config: withOverride(map[string]string{"certSerial": ""}), + wantErr: true, + errSubstr: "certSerial", + }, { name: "missing publicKey", config: withOverride(map[string]string{"publicKey": ""}), @@ -218,17 +284,29 @@ func TestNewWxpay(t *testing.T) { wantErr: true, errSubstr: "publicKeyId", }, + { + name: "malformed privateKey PEM", + config: withOverride(map[string]string{"privateKey": "not-a-valid-pem"}), + wantErr: true, + errSubstr: "WXPAY_CONFIG_INVALID_KEY", + }, + { + name: "malformed publicKey PEM", + config: withOverride(map[string]string{"publicKey": "not-a-valid-pem"}), + wantErr: true, + errSubstr: "WXPAY_CONFIG_INVALID_KEY", + }, { name: "apiV3Key too short", config: withOverride(map[string]string{"apiV3Key": "short"}), wantErr: true, - errSubstr: "exactly 32 bytes", + errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH", }, { name: "apiV3Key too long", config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes wantErr: true, - errSubstr: "exactly 32 bytes", + errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH", }, } @@ -257,3 +335,375 @@ func TestNewWxpay(t *testing.T) { }) } } + +func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) { + t.Parallel() + + resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{ + OrderID: "sub2_42", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("buildWxpayResultURL returned error: %v", err) + } + + parsed, err := url.Parse(resultURL) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + query := parsed.Query() + if parsed.Path != wxpayResultPath { + t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath) + } + if 
query.Get("resume_token") != "resume-42" { + t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42") + } + if query.Get("order_id") != "42" { + t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42") + } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42") + } +} + +func TestResolveWxpayJSAPIAppID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config map[string]string + want string + }{ + { + name: "prefers dedicated mp app id", + config: map[string]string{ + "mpAppId": "wx-mp-app", + "appId": "wx-merchant-app", + }, + want: "wx-mp-app", + }, + { + name: "falls back to merchant app id", + config: map[string]string{ + "appId": "wx-merchant-app", + }, + want: "wx-merchant-app", + }, + { + name: "missing app ids returns empty", + config: map[string]string{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want { + t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveWxpayCreateMode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req payment.CreatePaymentRequest + wantMode string + wantErr string + }{ + { + name: "desktop uses native", + req: payment.CreatePaymentRequest{}, + wantMode: wxpayModeNative, + }, + { + name: "mobile uses h5 when client ip is present", + req: payment.CreatePaymentRequest{ + IsMobile: true, + ClientIP: "203.0.113.10", + }, + wantMode: wxpayModeH5, + }, + { + name: "mobile without client ip returns clear error", + req: payment.CreatePaymentRequest{ + IsMobile: true, + }, + wantErr: "requires client IP", + }, + { + name: "openid uses jsapi mode", + req: payment.CreatePaymentRequest{ + OpenID: "openid-123", + }, + wantMode: wxpayModeJSAPI, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err 
:= resolveWxpayCreateMode(tt.req) + if tt.wantErr != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantMode { + t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode) + } + }) + } +} + +func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + if got := wxSV(req.Payer.Openid); got != "openid-123" { + t.Fatalf("openid = %q, want %q", got, "openid-123") + } + if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" { + t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10") + } + return &jsapi.PrepayWithRequestPaymentResponse{ + Appid: core.String("wx123"), + TimeStamp: core.String("1712345678"), + NonceStr: core.String("nonce-123"), + Package: core.String("prepay_id=wx_prepay_123"), + SignType: core.String("RSA"), + PaySign: core.String("signed-payload"), + }, nil, nil + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + 
h5Calls++ + return &h5.PrepayResponse{}, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + }, + coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: "sub2_88", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + NotifyURL: "https://merchant.example/payment/notify", + OpenID: "openid-123", + ClientIP: "203.0.113.10", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if jsapiCalls != 1 { + t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls) + } + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) + } + if h5Calls != 0 { + t.Fatalf("h5 prepay calls = %d, want 0", h5Calls) + } + if resp.ResultType != payment.CreatePaymentResultJSAPIReady { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady) + } + if resp.JSAPI == nil { + t.Fatal("expected jsapi payload, got nil") + } + if resp.JSAPI.AppID != "wx123" { + t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123") + } + if resp.JSAPI.TimeStamp != "1712345678" { + t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678") + } + if resp.JSAPI.NonceStr != "nonce-123" { + t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123") + } + if resp.JSAPI.Package != "prepay_id=wx_prepay_123" { + t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123") + } + if resp.JSAPI.SignType != "RSA" { + t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA") + } + if resp.JSAPI.PaySign != "signed-payload" { + t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload") + } +} + +func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + 
wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + h5Calls++ + if req.SceneInfo == nil { + t.Fatal("expected scene_info, got nil") + } + if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" { + t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10") + } + if req.SceneInfo.H5Info == nil { + t.Fatal("expected scene_info.h5_info, got nil") + } + if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type { + t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type) + } + if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" { + t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API") + } + if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" { + t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com") + } + return &h5.PrepayResponse{ + H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"), + }, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + "h5AppName": "Sub2API", + "h5AppUrl": "https://app.example.com", + }, + coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: 
"sub2_99", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + Subject: "Balance Recharge", + NotifyURL: "https://merchant.example/payment/notify", + ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99", + ClientIP: "203.0.113.10", + IsMobile: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if jsapiCalls != 0 { + t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) + } + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) + } + if h5Calls != 1 { + t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) + } + if !strings.Contains(resp.PayURL, "redirect_url=") { + t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL) + } +} + +func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) { + origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment + origNativePrepay := wxpayNativePrepay + origH5Prepay := wxpayH5Prepay + t.Cleanup(func() { + wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay + wxpayNativePrepay = origNativePrepay + wxpayH5Prepay = origH5Prepay + }) + + jsapiCalls := 0 + nativeCalls := 0 + h5Calls := 0 + wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { + jsapiCalls++ + return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil + } + wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { + h5Calls++ + return nil, nil, errors.New("NO_AUTH") + } + wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { + nativeCalls++ + return &native.PrepayResponse{ + CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"), + }, nil, nil + } + + provider := &Wxpay{ + config: map[string]string{ + "appId": "wx123", + "mchId": "mch123", + }, + 
coreClient: &core.Client{}, + } + + resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ + OrderID: "sub2_100", + Amount: "66.88", + PaymentType: payment.TypeWxpay, + Subject: "Balance Recharge", + NotifyURL: "https://merchant.example/payment/notify", + ClientIP: "203.0.113.10", + IsMobile: true, + }) + if err == nil { + t.Fatal("expected no-auth error, got nil") + } + if jsapiCalls != 0 { + t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) + } + if h5Calls != 1 { + t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) + } + if nativeCalls != 0 { + t.Fatalf("native prepay calls = %d, want 0", nativeCalls) + } + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } + if !strings.Contains(err.Error(), "NO_AUTH") { + t.Fatalf("error = %v, want NO_AUTH", err) + } +} diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go index 5d613a4a9d746480afbd4bbd583fc778a27ba617..e7ac6727b95c2cb9d56fb61fe400380751f68c52 100644 --- a/backend/internal/payment/types.go +++ b/backend/internal/payment/types.go @@ -101,34 +101,69 @@ type CreatePaymentRequest struct { Subject string // Product description NotifyURL string // Webhook callback URL ReturnURL string // Browser redirect URL after payment + OpenID string // WeChat JSAPI payer OpenID when available ClientIP string // Payer's IP address IsMobile bool // Whether the request comes from a mobile device InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe) } +// CreatePaymentResultType describes the shape of the create-payment result. +type CreatePaymentResultType = string + +const ( + CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created" + CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required" + CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready" +) + +// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment. 
+type WechatOAuthInfo struct { + AuthorizeURL string `json:"authorize_url,omitempty"` + AppID string `json:"appid,omitempty"` + OpenID string `json:"openid,omitempty"` + Scope string `json:"scope,omitempty"` + State string `json:"state,omitempty"` + RedirectURL string `json:"redirect_url,omitempty"` +} + +// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment. +type WechatJSAPIPayload struct { + AppID string `json:"appId,omitempty"` + TimeStamp string `json:"timeStamp,omitempty"` + NonceStr string `json:"nonceStr,omitempty"` + Package string `json:"package,omitempty"` + SignType string `json:"signType,omitempty"` + PaySign string `json:"paySign,omitempty"` +} + // CreatePaymentResponse is returned after successfully initiating a payment. type CreatePaymentResponse struct { - TradeNo string // Third-party transaction ID - PayURL string // H5 payment URL (alipay/wxpay) - QRCode string // QR code content for scanning - ClientSecret string // Stripe PaymentIntent client secret + TradeNo string // Third-party transaction ID + PayURL string // H5 payment URL (alipay/wxpay) + QRCode string // QR code content for scanning + ClientSecret string // Stripe PaymentIntent client secret + ResultType CreatePaymentResultType // Typed result contract for frontend flows + OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required + JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready } // QueryOrderResponse describes the payment status from the upstream provider. type QueryOrderResponse struct { - TradeNo string - Status string // "pending", "paid", "failed", "refunded" - Amount float64 // Amount in CNY - PaidAt string // RFC3339 timestamp or empty + TradeNo string + Status string // "pending", "paid", "failed", "refunded" + Amount float64 // Amount in CNY + PaidAt string // RFC3339 timestamp or empty + Metadata map[string]string } // PaymentNotification is the parsed result of a webhook/notify callback. 
type PaymentNotification struct { - TradeNo string - OrderID string - Amount float64 - Status string // "success" or "failed" - RawData string // Raw notification body for audit + TradeNo string + OrderID string + Amount float64 + Status string // "success" or "failed" + RawData string // Raw notification body for audit + Metadata map[string]string } // RefundRequest contains the parameters for requesting a refund. @@ -179,3 +214,9 @@ type CancelableProvider interface { // CancelPayment cancels/expires a pending payment on the upstream platform. CancelPayment(ctx context.Context, tradeNo string) error } + +// MerchantIdentityProvider exposes the current non-sensitive merchant identity +// derived from provider configuration for snapshot consistency checks. +type MerchantIdentityProvider interface { + MerchantIdentityMetadata() map[string]string +} diff --git a/backend/internal/payment/wire.go b/backend/internal/payment/wire.go index 9717465d8e87f55ab98b1bd1862e8566c4dd3211..4b7f422dec06881734f22ef60e146f1c5c7905db 100644 --- a/backend/internal/payment/wire.go +++ b/backend/internal/payment/wire.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "fmt" "log/slog" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" @@ -19,11 +20,22 @@ type EncryptionKey []byte // When the key is non-empty but invalid (bad hex or wrong length), an error is returned // to prevent startup with a misconfigured encryption key. 
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) { - if cfg.Totp.EncryptionKey == "" { + if cfg == nil { + slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable") + return nil, nil + } + keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey) + if keyHex == "" { slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable") return nil, nil } - key, err := hex.DecodeString(cfg.Totp.EncryptionKey) + // Reject auto-generated TOTP keys for payment signing. + // They change across restarts/instances and can silently break resume-token flows. + if !cfg.Totp.EncryptionKeyConfigured { + slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens") + return nil, nil + } + key, err := hex.DecodeString(keyHex) if err != nil { return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err) } diff --git a/backend/internal/payment/wire_test.go b/backend/internal/payment/wire_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1b360f89f737ff467516da1b44a1d919412e5850 --- /dev/null +++ b/backend/internal/payment/wire_test.go @@ -0,0 +1,62 @@ +package payment + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: strings.Repeat("a", 64), + EncryptionKeyConfigured: false, + }, + } + + key, err := ProvideEncryptionKey(cfg) + if err != nil { + t.Fatalf("ProvideEncryptionKey returned error: %v", err) + } + if len(key) != 0 { + t.Fatalf("encryption key len = %d, want 0", len(key)) + } +} + +func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: 
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + EncryptionKeyConfigured: true, + }, + } + + key, err := ProvideEncryptionKey(cfg) + if err != nil { + t.Fatalf("ProvideEncryptionKey returned error: %v", err) + } + if len(key) != 32 { + t.Fatalf("encryption key len = %d, want 32", len(key)) + } +} + +func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Totp: config.TotpConfig{ + EncryptionKey: "abcd", + EncryptionKeyConfigured: true, + }, + } + + _, err := ProvideEncryptionKey(cfg) + if err == nil { + t.Fatal("expected error for invalid key length") + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 49e38bf8154fbc891bfd51c2ee5aded90ee663f0..be9f3aae78942832cca0790f36faca43e1aa9f1f 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,18 +15,15 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ + {ID: "gpt-5.5", Object: "model", Created: 1776873600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.5"}, {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"}, - {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, - {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", 
DisplayName: "GPT-5.2 Codex"}, - {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, - {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"}, - {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"}, - {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"}, - {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"}, + {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"}, + {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"}, + {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"}, } // DefaultModelIDs returns the default model ID list @@ -39,7 +36,7 @@ func DefaultModelIDs() []string { } // DefaultTestModel default model for testing OpenAI accounts -const DefaultTestModel = "gpt-5.1-codex" +const DefaultTestModel = "gpt-5.4" // DefaultInstructions default instructions for non-Codex CLI requests // Content loaded from instructions.txt at compile time diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 24115c33d74bbdde478b0d28e9feefb93f9dbe58..78f739ac205ab53c13113769625b0df55f97c177 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { return err } + if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE 
account_id = $1", id); err != nil { + return err + } if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil { return err } diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go index 2dc346b15544ef3f8a886bbec91acf145a3d8894..5268ec45ff610cdcf37fce9e42f54b6f223a150f 100644 --- a/backend/internal/repository/announcement_read_repo.go +++ b/backend/internal/repository/announcement_read_repo.go @@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error { client := clientFromContext(ctx, r.client) - return client.AnnouncementRead.Create(). + err := client.AnnouncementRead.Create(). SetAnnouncementID(announcementID). SetUserID(userID). SetReadAt(readAt). OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID). DoNothing(). Exec(ctx) + if isSQLNoRowsError(err) { + return nil + } + return err } func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 38ea9bde3d73e704e014d793dd1bb366a277c5cc..3a52740512ddd800279e7b0046f2b5a00045ccec 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -149,6 +149,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyExtraEmails, user.FieldTotalRecharged, + user.FieldSignupSource, + user.FieldLastLoginAt, + user.FieldLastActiveAt, + user.FieldRpmLimit, ) }). 
WithGroup(func(q *dbent.GroupQuery) { @@ -175,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, group.FieldMessagesDispatchModelConfig, + group.FieldRpmLimit, ) }). Only(ctx) @@ -656,6 +661,9 @@ func userEntityToService(u *dbent.User) *service.User { Balance: u.Balance, Concurrency: u.Concurrency, Status: u.Status, + SignupSource: u.SignupSource, + LastLoginAt: u.LastLoginAt, + LastActiveAt: u.LastActiveAt, TotpSecretEncrypted: u.TotpSecretEncrypted, TotpEnabled: u.TotpEnabled, TotpEnabledAt: u.TotpEnabledAt, @@ -663,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User { BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, BalanceNotifyThreshold: u.BalanceNotifyThreshold, TotalRecharged: u.TotalRecharged, + RPMLimit: u.RpmLimit, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, } @@ -707,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + RPMLimit: g.RpmLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7e34777ad8a7fa45fd3b6aabb40c9e613d27cfb5 --- /dev/null +++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go @@ -0,0 +1,80 @@ +//go:build integration + +package repository + +import ( + "context" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration108Path := filepath.Join("..", "..", "migrations", 
"108_auth_identity_foundation_core.sql") + migration108SQL, err := os.ReadFile(migration108Path) + require.NoError(t, err) + + migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql") + migration108aSQL, err := os.ReadFile(migration108aPath) + require.NoError(t, err) + + migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql") + migration109SQL, err := os.ReadFile(migration109Path) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE; +DROP TABLE IF EXISTS auth_identity_channels CASCADE; +DROP TABLE IF EXISTS identity_adoption_decisions CASCADE; +DROP TABLE IF EXISTS pending_auth_sessions CASCADE; +DROP TABLE IF EXISTS auth_identities CASCADE; + +ALTER TABLE users + DROP COLUMN IF EXISTS signup_source, + DROP COLUMN IF EXISTS last_login_at, + DROP COLUMN IF EXISTS last_active_at; +`) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration108SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration108aSQL)) + require.NoError(t, err) + + var userID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&userID)) + + _, err = tx.ExecContext(ctx, string(migration109SQL)) + require.NoError(t, err) + + var reportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery' + AND report_key = $1 +`, strconv.FormatInt(userID, 10)).Scan(&reportCount)) + require.Equal(t, 1, reportCount) + + var reportTypeLimit int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT character_maximum_length +FROM information_schema.columns +WHERE table_schema = 'public' + AND 
table_name = 'auth_identity_migration_reports' + AND column_name = 'report_type' +`).Scan(&reportTypeLimit)) + require.GreaterOrEqual(t, reportTypeLimit, 45) + + require.NotZero(t, userID) +} diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e64934c531421ecfd1e372bab4511d4f3f9b4704 --- /dev/null +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -0,0 +1,959 @@ +//go:build integration + +package repository + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoUserID)) + + var wechatUnionUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatUnionUserID)) + + var wechatOpenIDOnlyUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING 
id`).Scan(&wechatOpenIDOnlyUserID)) + + var syntheticAuthIdentityID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb) +RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID)) + + var linuxDoLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}') +RETURNING id +`, linuxDoUserID).Scan(&linuxDoLegacyID)) + + var wechatUnionLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}') +RETURNING id +`, wechatUnionUserID).Scan(&wechatUnionLegacyID)) + + var wechatOpenIDLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}') +RETURNING id +`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxDoCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-user-1' +`, 
linuxDoUserID).Scan(&linuxDoCount)) + require.Equal(t, 1, linuxDoCount) + + var wechatSubject string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT provider_subject +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-1' +`, wechatUnionUserID).Scan(&wechatSubject)) + require.Equal(t, "union-1", wechatSubject) + + var wechatChannelCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_channels channel +JOIN auth_identities ai ON ai.id = channel.identity_id +WHERE ai.user_id = $1 + AND channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = 'oa' + AND channel.channel_app_id = 'wx-app-1' + AND channel.channel_subject = 'openid-union-1' +`, wechatUnionUserID).Scan(&wechatChannelCount)) + require.Equal(t, 1, wechatChannelCount) + + var legacyOpenIDOnlyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount)) + require.Equal(t, 1, legacyOpenIDOnlyReportCount) + + var syntheticReviewCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount)) + require.Equal(t, 1, syntheticReviewCount) + + var unionLegacyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 
10)).Scan(&unionLegacyReportCount)) + require.Zero(t, unionLegacyReportCount) + require.NotZero(t, linuxDoLegacyID) +} + +func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + var beforeCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +`).Scan(&beforeCount)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var afterCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports + `).Scan(&afterCount)) + require.Equal(t, beforeCount, afterCount) +} + +func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoMalformedUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoMalformedUserID)) + + var linuxDoArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES 
('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoArrayUserID)) + + var wechatUnionArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatUnionArrayUserID)) + + var wechatOpenIDArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatOpenIDArrayUserID)) + + var linuxDoMalformedLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid') +RETURNING id +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID)) + + var linuxDoArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]') +RETURNING id +`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID)) + + var wechatUnionArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]') +RETURNING id +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID)) + + 
var wechatOpenIDArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]') +RETURNING id +`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var linuxDoMalformedMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-malformed' +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType)) + require.Equal(t, "object", linuxDoMalformedMetadataType) + + var linuxDoArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-array' +`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType)) + require.Equal(t, "object", linuxDoArrayMetadataType) + + var wechatUnionArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-array' +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType)) + require.Equal(t, "object", wechatUnionArrayMetadataType) + + var invalidJSONReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 
'legacy_external_identity_invalid_metadata_json' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType)) + require.Equal(t, "object", invalidJSONReportDetailsType) + + var openIDOnlyReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType)) + require.Equal(t, "object", openIDOnlyReportDetailsType) + + var preservedArrayMetadataCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE id IN ( + SELECT id + FROM auth_identities + WHERE (user_id = $1 AND provider_subject = 'linuxdo-array') + OR (user_id = $2 AND provider_subject = 'union-array') +) + AND metadata ? '_legacy_metadata_raw_json' +`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount)) + require.Equal(t, 2, preservedArrayMetadataCount) + + require.NotZero(t, linuxDoArrayLegacyID) + require.NotZero(t, wechatUnionArrayLegacyID) +} + +func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + userIDs := make([]int64, 0, 8) + for _, email := range []string{ + "linuxdo-conflict-legacy@example.com", + "linuxdo-conflict-owner@example.com", + "wechat-conflict-legacy@example.com", + "wechat-conflict-owner@example.com", + "wechat-channel-legacy@example.com", + 
"wechat-channel-owner@example.com", + "linuxdo-invalid-json@example.com", + "wechat-openid-invalid-json@example.com", + } { + var userID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ($1, 'hash', 'user', 'active', 0, 1) +RETURNING id`, email).Scan(&userID)) + userIDs = append(userIDs, userID) + } + + linuxdoConflictLegacyUserID := userIDs[0] + linuxdoConflictOwnerUserID := userIDs[1] + wechatConflictLegacyUserID := userIDs[2] + wechatConflictOwnerUserID := userIDs[3] + wechatChannelLegacyUserID := userIDs[4] + wechatChannelOwnerUserID := userIDs[5] + linuxdoInvalidJSONUserID := userIDs[6] + wechatInvalidOpenIDUserID := userIDs[7] + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb) +RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb) +RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64))) + + var wechatChannelOwnerIdentityID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata) +VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb) +RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID)) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb) +RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64))) + + var 
linuxdoConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}') +RETURNING id +`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID)) + + var wechatConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}') +RETURNING id +`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID)) + + var wechatChannelConflictLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}') +RETURNING id +`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID)) + + var linuxdoInvalidJSONLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid') +RETURNING id +`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID)) + + var wechatInvalidOpenIDLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + 
provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid') +RETURNING id +`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxdoConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount)) + require.Equal(t, 1, linuxdoConflictReportCount) + + var wechatConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount)) + require.Equal(t, 1, wechatConflictReportCount) + + var channelConflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_channel_conflict' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount)) + require.Equal(t, 1, channelConflictReportCount) + + var invalidJSONReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key IN ($1, $2) +`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount)) + require.Equal(t, 
2, invalidJSONReportCount) + + var linuxdoInvalidIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-invalid-json' +`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount)) + require.Equal(t, 1, linuxdoInvalidIdentityCount) + + var wechatOpenIDOnlyReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount)) + require.Equal(t, 1, wechatOpenIDOnlyReportCount) +} + +func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + var beforeCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +`).Scan(&beforeCount)) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var afterCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports + `).Scan(&afterCount)) + require.Equal(t, beforeCount, afterCount) +} + +func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migrationSQL, err := os.ReadFile(migrationPath) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + 
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(new(int64))) + + 
require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}') +RETURNING id +`, wechatFirstUserID).Scan(new(int64))) + + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}') +RETURNING id +`, wechatSecondUserID).Scan(new(int64))) + + _, err = tx.ExecContext(ctx, string(migrationSQL)) + require.NoError(t, err) + + var linuxDoIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-ambiguous-subject' +`).Scan(&linuxDoIdentityCount)) + require.Zero(t, linuxDoIdentityCount) + + var wechatIdentityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-ambiguous-subject' +`).Scan(&wechatIdentityCount)) + require.Zero(t, wechatIdentityCount) + + var wechatChannelCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_channels +WHERE provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND channel = 'oa' + AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b') +`).Scan(&wechatChannelCount)) + require.Zero(t, wechatChannelCount) +} + +func 
TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + var linuxDoFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoFirstUserID)) + + var linuxDoSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoSecondUserID)) + + var wechatFirstUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatFirstUserID)) + + var wechatSecondUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatSecondUserID)) + + var linuxDoFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + 
provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}') +RETURNING id +`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID)) + + var linuxDoSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}') +RETURNING id +`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID)) + + var wechatFirstLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}') +RETURNING id +`, wechatFirstUserID).Scan(&wechatFirstLegacyID)) + + var wechatSecondLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}') +RETURNING id +`, wechatSecondUserID).Scan(&wechatSecondLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var identityCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND 
provider_subject = 'linuxdo-conflict-subject') + OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject') +`).Scan(&identityCount)) + require.Zero(t, identityCount) + + var conflictReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount)) + require.Equal(t, 4, conflictReportCount) + + var winnerAttributedReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_conflict' + AND report_key IN ($1, $2, $3, $4) + AND details ->> 'existing_identity_id' IS NOT NULL +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount)) + require.Zero(t, winnerAttributedReportCount) +} + +func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql") + migration108aSQL, err := os.ReadFile(migration108aPath) + require.NoError(t, err) + + migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql") + migration109SQL, err := 
os.ReadFile(migration109Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + prepareLegacyExternalIdentitiesTable(t, tx, ctx) + truncateAuthIdentityLegacyFixtureTables(t, tx, ctx) + + _, err = tx.ExecContext(ctx, ` +ALTER TABLE auth_identity_migration_reports +ALTER COLUMN report_type TYPE VARCHAR(40); +`) + require.NoError(t, err) + + var oidcSyntheticUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&oidcSyntheticUserID)) + + var linuxdoLegacyUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxdoLegacyUserID)) + + var invalidMetadataLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid') +RETURNING id +`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration108aSQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration109SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var reportTypeWidth int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT character_maximum_length +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = 'auth_identity_migration_reports' + 
AND column_name = 'report_type' +`).Scan(&reportTypeWidth)) + require.Equal(t, 80, reportTypeWidth) + + var oidcSyntheticRecoveryReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery' + AND report_key = $1 +`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount)) + require.Equal(t, 1, oidcSyntheticRecoveryReportCount) + + var invalidMetadataReportCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount)) + require.Equal(t, 1, invalidMetadataReportCount) +} + +func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) { + t.Helper() + + _, err := tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); +`) + require.NoError(t, err) +} + +func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) { + t.Helper() + + _, err := tx.ExecContext(ctx, ` +TRUNCATE TABLE + auth_identity_channels, + identity_adoption_decisions, + pending_auth_sessions, + auth_identities, + auth_identity_migration_reports, + user_provider_default_grants, + user_avatars, + user_external_identities, + users +RESTART IDENTITY CASCADE; 
+`) + require.NoError(t, err) +} diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go new file mode 100644 index 0000000000000000000000000000000000000000..800ee43b2ee341734fdb924b536b07d49c29d80d --- /dev/null +++ b/backend/internal/repository/channel_monitor_repo.go @@ -0,0 +1,755 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// channelMonitorRepository 实现 service.ChannelMonitorRepository。 +// +// 选型说明: +// - CRUD 走 ent,复用项目的事务上下文支持 +// - 聚合查询(latest per model / availability)走原生 SQL,避免 ent 在 GROUP BY 上 +// 的样板代码,并保证索引能被命中 +type channelMonitorRepository struct { + client *dbent.Client + db *sql.DB +} + +// NewChannelMonitorRepository 创建仓储实例。 +func NewChannelMonitorRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRepository { + return &channelMonitorRepository{client: client, db: db} +} + +// ---------- CRUD ---------- + +func (r *channelMonitorRepository) Create(ctx context.Context, m *service.ChannelMonitor) error { + client := clientFromContext(ctx, r.client) + builder := client.ChannelMonitor.Create(). + SetName(m.Name). + SetProvider(channelmonitor.Provider(m.Provider)). + SetEndpoint(m.Endpoint). + SetAPIKeyEncrypted(m.APIKey). // 调用方传入的已是密文 + SetPrimaryModel(m.PrimaryModel). + SetExtraModels(emptySliceIfNil(m.ExtraModels)). + SetGroupName(m.GroupName). + SetEnabled(m.Enabled). + SetIntervalSeconds(m.IntervalSeconds). + SetCreatedBy(m.CreatedBy). + SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)). 
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode)) + if m.TemplateID != nil { + builder = builder.SetTemplateID(*m.TemplateID) + } + if m.BodyOverride != nil { + builder = builder.SetBodyOverride(m.BodyOverride) + } + + created, err := builder.Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil) + } + m.ID = created.ID + m.CreatedAt = created.CreatedAt + m.UpdatedAt = created.UpdatedAt + return nil +} + +func (r *channelMonitorRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitor, error) { + row, err := r.client.ChannelMonitor.Query(). + Where(channelmonitor.IDEQ(id)). + Only(ctx) + if err != nil { + return nil, translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil) + } + return entToServiceMonitor(row), nil +} + +func (r *channelMonitorRepository) Update(ctx context.Context, m *service.ChannelMonitor) error { + client := clientFromContext(ctx, r.client) + updater := client.ChannelMonitor.UpdateOneID(m.ID). + SetName(m.Name). + SetProvider(channelmonitor.Provider(m.Provider)). + SetEndpoint(m.Endpoint). + SetAPIKeyEncrypted(m.APIKey). + SetPrimaryModel(m.PrimaryModel). + SetExtraModels(emptySliceIfNil(m.ExtraModels)). + SetGroupName(m.GroupName). + SetEnabled(m.Enabled). + SetIntervalSeconds(m.IntervalSeconds). + SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)). 
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode)) + if m.TemplateID != nil { + updater = updater.SetTemplateID(*m.TemplateID) + } else { + updater = updater.ClearTemplateID() + } + if m.BodyOverride != nil { + updater = updater.SetBodyOverride(m.BodyOverride) + } else { + updater = updater.ClearBodyOverride() + } + + updated, err := updater.Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil) + } + m.UpdatedAt = updated.UpdatedAt + return nil +} + +func (r *channelMonitorRepository) Delete(ctx context.Context, id int64) error { + client := clientFromContext(ctx, r.client) + if err := client.ChannelMonitor.DeleteOneID(id).Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil) + } + return nil +} + +func (r *channelMonitorRepository) List(ctx context.Context, params service.ChannelMonitorListParams) ([]*service.ChannelMonitor, int64, error) { + q := r.client.ChannelMonitor.Query() + if params.Provider != "" { + q = q.Where(channelmonitor.ProviderEQ(channelmonitor.Provider(params.Provider))) + } + if params.Enabled != nil { + q = q.Where(channelmonitor.EnabledEQ(*params.Enabled)) + } + if s := strings.TrimSpace(params.Search); s != "" { + q = q.Where(channelmonitor.Or( + channelmonitor.NameContainsFold(s), + channelmonitor.GroupNameContainsFold(s), + channelmonitor.PrimaryModelContainsFold(s), + )) + } + + total, err := q.Count(ctx) + if err != nil { + return nil, 0, fmt.Errorf("count monitors: %w", err) + } + + pageSize := params.PageSize + if pageSize <= 0 { + pageSize = 20 + } + page := params.Page + if page <= 0 { + page = 1 + } + + rows, err := q. + Order(dbent.Desc(channelmonitor.FieldID)). + Offset((page - 1) * pageSize). + Limit(pageSize). 
+ All(ctx) + if err != nil { + return nil, 0, fmt.Errorf("list monitors: %w", err) + } + + out := make([]*service.ChannelMonitor, 0, len(rows)) + for _, row := range rows { + out = append(out, entToServiceMonitor(row)) + } + return out, int64(total), nil +} + +// ---------- 调度器辅助 ---------- + +func (r *channelMonitorRepository) ListEnabled(ctx context.Context) ([]*service.ChannelMonitor, error) { + rows, err := r.client.ChannelMonitor.Query(). + Where(channelmonitor.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("list enabled monitors: %w", err) + } + out := make([]*service.ChannelMonitor, 0, len(rows)) + for _, row := range rows { + out = append(out, entToServiceMonitor(row)) + } + return out, nil +} + +func (r *channelMonitorRepository) MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error { + client := clientFromContext(ctx, r.client) + if err := client.ChannelMonitor.UpdateOneID(id). + SetLastCheckedAt(checkedAt). + Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil) + } + return nil +} + +func (r *channelMonitorRepository) InsertHistoryBatch(ctx context.Context, rows []*service.ChannelMonitorHistoryRow) error { + if len(rows) == 0 { + return nil + } + client := clientFromContext(ctx, r.client) + bulk := make([]*dbent.ChannelMonitorHistoryCreate, 0, len(rows)) + for _, row := range rows { + c := client.ChannelMonitorHistory.Create(). + SetMonitorID(row.MonitorID). + SetModel(row.Model). + SetStatus(channelmonitorhistory.Status(row.Status)). + SetMessage(row.Message). 
+ SetCheckedAt(row.CheckedAt) + if row.LatencyMs != nil { + c = c.SetLatencyMs(*row.LatencyMs) + } + if row.PingLatencyMs != nil { + c = c.SetPingLatencyMs(*row.PingLatencyMs) + } + bulk = append(bulk, c) + } + if _, err := client.ChannelMonitorHistory.CreateBulk(bulk...).Save(ctx); err != nil { + return fmt.Errorf("insert history bulk: %w", err) + } + return nil +} + +// DeleteHistoryBefore 物理删 checked_at < before 的明细,分批 channelMonitorPruneBatchSize 行一批, +// 避免单事务删除过多引起锁/WAL 压力。借助 (checked_at) 索引定位小批 id,再按 id 删。 +func (r *channelMonitorRepository) DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error) { + return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneHistorySQL, before) +} + +// ListHistory 按 checked_at 倒序返回某个监控的最近 N 条历史记录。 +// model 为空时不过滤;非空时只返回该模型的记录。 +func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*service.ChannelMonitorHistoryEntry, error) { + q := r.client.ChannelMonitorHistory.Query(). + Where(channelmonitorhistory.MonitorIDEQ(monitorID)) + if strings.TrimSpace(model) != "" { + q = q.Where(channelmonitorhistory.ModelEQ(model)) + } + rows, err := q. + Order(dbent.Desc(channelmonitorhistory.FieldCheckedAt)). + Limit(limit). 
+ All(ctx) + if err != nil { + return nil, fmt.Errorf("list history: %w", err) + } + out := make([]*service.ChannelMonitorHistoryEntry, 0, len(rows)) + for _, row := range rows { + entry := &service.ChannelMonitorHistoryEntry{ + ID: row.ID, + Model: row.Model, + Status: string(row.Status), + LatencyMs: row.LatencyMs, + PingLatencyMs: row.PingLatencyMs, + Message: row.Message, + CheckedAt: row.CheckedAt, + } + out = append(out, entry) + } + return out, nil +} + +// ---------- 用户视图聚合(原生 SQL) ---------- + +// ListLatestPerModel 用 DISTINCT ON 取每个 (monitor_id, model) 的最近一条记录。 +// 借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。 +func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) { + const q = ` + SELECT DISTINCT ON (model) + model, status, latency_ms, ping_latency_ms, checked_at + FROM channel_monitor_histories + WHERE monitor_id = $1 + ORDER BY model, checked_at DESC + ` + rows, err := r.db.QueryContext(ctx, q, monitorID) + if err != nil { + return nil, fmt.Errorf("query latest per model: %w", err) + } + defer func() { _ = rows.Close() }() + + out := make([]*service.ChannelMonitorLatest, 0) + for rows.Next() { + l := &service.ChannelMonitorLatest{} + var latency, ping sql.NullInt64 + if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil { + return nil, fmt.Errorf("scan latest row: %w", err) + } + assignNullInt(&l.LatencyMs, latency) + assignNullInt(&l.PingLatencyMs, ping) + out = append(out, l) + } + return out, rows.Err() +} + +// assignNullInt 把 sql.NullInt64 解包到 *int 指针目标(valid 才分配新 int)。 +// 集中实现避免 latency / ping 两处重复 if latency.Valid { v := int(...) ... 
} 模板。 +func assignNullInt(dst **int, n sql.NullInt64) { + if !n.Valid { + return + } + v := int(n.Int64) + *dst = &v +} + +// ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。 +// "可用" = status IN (operational, degraded)。 +// +// 数据来源:明细表只保留 1 天;窗口前其余天数走聚合表。 +// 明细保留 30 天(monitorHistoryRetentionDays),窗口 <= 30 天时直接扫 histories, +// 精度到秒,避免与聚合表 UNION 带来的 UTC 日切精度损失。 +func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) { + if windowDays <= 0 { + windowDays = 7 + } + const q = ` + SELECT model, + COUNT(*) AS total, + COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok, + CASE WHEN COUNT(latency_ms) > 0 + THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms) + ELSE NULL END AS avg_latency_ms + FROM channel_monitor_histories + WHERE monitor_id = $1 + AND checked_at >= NOW() - ($2::int || ' days')::interval + GROUP BY model + ` + rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays) + if err != nil { + return nil, fmt.Errorf("query availability: %w", err) + } + defer func() { _ = rows.Close() }() + + out := make([]*service.ChannelMonitorAvailability, 0) + for rows.Next() { + row, err := scanAvailabilityRow(rows, windowDays) + if err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +// scanAvailabilityRow 把单行 (model, total, ok, avg_latency) 扫描为 ChannelMonitorAvailability。 +// 仅服务于 ComputeAvailability(4 列);批量版本因为多一列 monitor_id 直接 inline 调 finalizeAvailabilityRow。 +func scanAvailabilityRow(rows interface{ Scan(...any) error }, windowDays int) (*service.ChannelMonitorAvailability, error) { + row := &service.ChannelMonitorAvailability{WindowDays: windowDays} + var avgLatency sql.NullFloat64 + if err := rows.Scan(&row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil { + return nil, fmt.Errorf("scan availability row: %w", err) + } + finalizeAvailabilityRow(row, 
avgLatency) + return row, nil +} + +// finalizeAvailabilityRow 根据 OperationalChecks/TotalChecks 算出可用率, +// 并把 sql.NullFloat64 的平均延迟解包为 *int。两处复用避免维护漂移。 +func finalizeAvailabilityRow(row *service.ChannelMonitorAvailability, avgLatency sql.NullFloat64) { + if row.TotalChecks > 0 { + row.AvailabilityPct = float64(row.OperationalChecks) * 100.0 / float64(row.TotalChecks) + } + if avgLatency.Valid { + v := int(avgLatency.Float64) + row.AvgLatencyMs = &v + } +} + +// ListLatestForMonitorIDs 一次性查询多个监控的"每个 (monitor_id, model) 最近一条"记录。 +// 利用 PG 的 DISTINCT ON 特性,借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。 +func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*service.ChannelMonitorLatest, error) { + out := make(map[int64][]*service.ChannelMonitorLatest, len(ids)) + if len(ids) == 0 { + return out, nil + } + const q = ` + SELECT DISTINCT ON (monitor_id, model) + monitor_id, model, status, latency_ms, ping_latency_ms, checked_at + FROM channel_monitor_histories + WHERE monitor_id = ANY($1) + ORDER BY monitor_id, model, checked_at DESC + ` + rows, err := r.db.QueryContext(ctx, q, pq.Array(ids)) + if err != nil { + return nil, fmt.Errorf("query latest batch: %w", err) + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var monitorID int64 + l := &service.ChannelMonitorLatest{} + var latency, ping sql.NullInt64 + if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil { + return nil, fmt.Errorf("scan latest batch row: %w", err) + } + assignNullInt(&l.LatencyMs, latency) + assignNullInt(&l.PingLatencyMs, ping) + out[monitorID] = append(out[monitorID], l) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +// ListRecentHistoryForMonitors 为多个 monitor 批量取各自"指定模型"最近 N 条历史(按 checked_at DESC,最新在前)。 +// primaryModels[monitorID] 指定该监控要过滤的模型名;monitor 不在 primaryModels 中的记录不返回。 +// 通过 CTE + unnest(两个 int8/text 数组) 构造 (monitor_id, model) 
白名单, +// 再用 ROW_NUMBER() OVER (PARTITION BY monitor_id) 取各自前 N 条。 +// +// 返回值:map[monitorID] -> []*ChannelMonitorHistoryEntry(不含 message,减少网络开销)。 +// 空 ids / 空 primaryModels 返回空 map,不报错。 +func (r *channelMonitorRepository) ListRecentHistoryForMonitors( + ctx context.Context, + ids []int64, + primaryModels map[int64]string, + perMonitorLimit int, +) (map[int64][]*service.ChannelMonitorHistoryEntry, error) { + out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids)) + pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels) + if len(pairIDs) == 0 { + return out, nil + } + perMonitorLimit = clampTimelineLimit(perMonitorLimit) + + const q = ` + WITH targets AS ( + SELECT unnest($1::bigint[]) AS monitor_id, + unnest($2::text[]) AS model + ), + ranked AS ( + SELECT h.monitor_id, + h.status, + h.latency_ms, + h.ping_latency_ms, + h.checked_at, + ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn + FROM channel_monitor_histories h + JOIN targets t + ON t.monitor_id = h.monitor_id AND t.model = h.model + ) + SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at + FROM ranked + WHERE rn <= $3 + ORDER BY monitor_id, checked_at DESC + ` + rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit) + if err != nil { + return nil, fmt.Errorf("query recent history batch: %w", err) + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var monitorID int64 + entry := &service.ChannelMonitorHistoryEntry{} + var latency, ping sql.NullInt64 + if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil { + return nil, fmt.Errorf("scan recent history row: %w", err) + } + assignNullInt(&entry.LatencyMs, latency) + assignNullInt(&entry.PingLatencyMs, ping) + out[monitorID] = append(out[monitorID], entry) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +// buildMonitorModelPairs 基于 ids 过滤出有效的 (monitor_id, 
model) 对,model 为空时跳过。 +// 保证两个数组长度一致且一一对应,供 unnest 展开。 +func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) { + if len(ids) == 0 || len(primaryModels) == 0 { + return nil, nil + } + pairIDs := make([]int64, 0, len(ids)) + pairModels := make([]string, 0, len(ids)) + for _, id := range ids { + model, ok := primaryModels[id] + if !ok || strings.TrimSpace(model) == "" { + continue + } + pairIDs = append(pairIDs, id) + pairModels = append(pairModels, model) + } + return pairIDs, pairModels +} + +// timelineLimit* 批量 timeline 查询的 perMonitorLimit 夹紧范围。 +// 下限 1 表示至少返回最近一条;上限 200 控制单次响应体与 SQL 内存占用(ROW_NUMBER 窗口上限)。 +const ( + timelineLimitMin = 1 + timelineLimitMax = 200 +) + +// clampTimelineLimit 把 perMonitorLimit 夹紧到 [timelineLimitMin, timelineLimitMax],避免非法值或超大查询。 +func clampTimelineLimit(n int) int { + if n < timelineLimitMin { + return timelineLimitMin + } + if n > timelineLimitMax { + return timelineLimitMax + } + return n +} + +// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。 +// 明细保留 30 天,直接扫 histories(窗口 <= 30 天时无需聚合)。 +func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) { + out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids)) + if len(ids) == 0 { + return out, nil + } + if windowDays <= 0 { + windowDays = 7 + } + const q = ` + SELECT monitor_id, + model, + COUNT(*) AS total, + COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok, + CASE WHEN COUNT(latency_ms) > 0 + THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms) + ELSE NULL END AS avg_latency_ms + FROM channel_monitor_histories + WHERE monitor_id = ANY($1) + AND checked_at >= NOW() - ($2::int || ' days')::interval + GROUP BY monitor_id, model + ` + rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays) + if err != nil { + return nil, fmt.Errorf("query 
availability batch: %w", err) + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var monitorID int64 + row := &service.ChannelMonitorAvailability{WindowDays: windowDays} + var avgLatency sql.NullFloat64 + if err := rows.Scan(&monitorID, &row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil { + return nil, fmt.Errorf("scan availability batch row: %w", err) + } + // 批量查询多了首列 monitor_id;其余字段的可用率/平均延迟换算与单 monitor 版本一致, + // 抽出 finalizeAvailabilityRow 复用,避免两处分别维护除法与 NullFloat 解包。 + finalizeAvailabilityRow(row, avgLatency) + out[monitorID] = append(out[monitorID], row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil +} + +// ---------- 聚合维护 ---------- + +// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))的明细 +// 按 (monitor_id, model, bucket_date) 聚合写入 channel_monitor_daily_rollups。 +// - 用 ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE 实现幂等回填, +// 重复执行只会用最新统计覆盖; +// - $1::date 让 PG 自动把入参 truncate 到 UTC 日期,调用方不需要预处理 targetDate。 +func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) { + const q = ` + INSERT INTO channel_monitor_daily_rollups ( + monitor_id, model, bucket_date, + total_checks, ok_count, + operational_count, degraded_count, failed_count, error_count, + sum_latency_ms, count_latency, + sum_ping_latency_ms, count_ping_latency, + computed_at + ) + SELECT + monitor_id, + model, + $1::date AS bucket_date, + COUNT(*) AS total_checks, + COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count, + COUNT(*) FILTER (WHERE status = 'operational') AS operational_count, + COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count, + COUNT(*) FILTER (WHERE status = 'failed') AS failed_count, + COUNT(*) FILTER (WHERE status = 'error') AS error_count, + COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms, + COUNT(latency_ms) AS count_latency, + 
COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms, + COUNT(ping_latency_ms) AS count_ping_latency, + NOW() + FROM channel_monitor_histories + WHERE checked_at >= $1::date + AND checked_at < ($1::date + INTERVAL '1 day') + GROUP BY monitor_id, model + ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET + total_checks = EXCLUDED.total_checks, + ok_count = EXCLUDED.ok_count, + operational_count = EXCLUDED.operational_count, + degraded_count = EXCLUDED.degraded_count, + failed_count = EXCLUDED.failed_count, + error_count = EXCLUDED.error_count, + sum_latency_ms = EXCLUDED.sum_latency_ms, + count_latency = EXCLUDED.count_latency, + sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms, + count_ping_latency = EXCLUDED.count_ping_latency, + computed_at = NOW() + ` + res, err := r.db.ExecContext(ctx, q, targetDate) + if err != nil { + return 0, fmt.Errorf("upsert daily rollups for %s: %w", targetDate.Format("2006-01-02"), err) + } + n, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("rows affected (upsert rollups): %w", err) + } + return n, nil +} + +// DeleteRollupsBefore 物理删 bucket_date < beforeDate 的聚合行,同样分批。 +func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) { + return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneRollupSQL, beforeDate) +} + +// channelMonitorPruneBatchSize 单批删除上限。与 ops_cleanup_service 保持一致的 5000, +// 在大表上按 id 小批删可以避免长事务和 WAL 堆积。 +const channelMonitorPruneBatchSize = 5000 + +// channelMonitorPruneHistorySQL 分批物理删明细表过期行。 +const channelMonitorPruneHistorySQL = ` +WITH batch AS ( + SELECT id FROM channel_monitor_histories + WHERE checked_at < $1 + ORDER BY id + LIMIT $2 +) +DELETE FROM channel_monitor_histories +WHERE id IN (SELECT id FROM batch) +` + +// channelMonitorPruneRollupSQL 分批物理删 rollup 表过期行。bucket_date 需要 ::date 转型 +// 保证与 DATE 列一致比较。 +const channelMonitorPruneRollupSQL = ` +WITH batch AS ( + SELECT 
id FROM channel_monitor_daily_rollups + WHERE bucket_date < $1::date + ORDER BY id + LIMIT $2 +) +DELETE FROM channel_monitor_daily_rollups +WHERE id IN (SELECT id FROM batch) +` + +// deleteChannelMonitorBatched 循环执行分批 DELETE,直到影响行为 0。返回累计删除行数。 +// cutoff 由调用方按列类型传入(明细用 time.Time 对 TIMESTAMPTZ,rollup 用 time.Time SQL 侧 ::date 转型)。 +func deleteChannelMonitorBatched(ctx context.Context, db *sql.DB, query string, cutoff time.Time) (int64, error) { + var total int64 + for { + res, err := db.ExecContext(ctx, query, cutoff, channelMonitorPruneBatchSize) + if err != nil { + return total, fmt.Errorf("channel_monitor prune batch: %w", err) + } + affected, err := res.RowsAffected() + if err != nil { + return total, fmt.Errorf("channel_monitor prune rows affected: %w", err) + } + total += affected + if affected == 0 { + break + } + } + return total, nil +} + +// LoadAggregationWatermark 读 watermark 表(id=1)。 +// watermark 表不是 ent schema(只有一行),直接走原生 SQL。 +// - 行不存在或 last_aggregated_date IS NULL:返回 (nil, nil),由调用方决定首次回填策略 +func (r *channelMonitorRepository) LoadAggregationWatermark(ctx context.Context) (*time.Time, error) { + const q = `SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1` + var t sql.NullTime + if err := r.db.QueryRowContext(ctx, q).Scan(&t); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("load aggregation watermark: %w", err) + } + if !t.Valid { + return nil, nil + } + return &t.Time, nil +} + +// UpdateAggregationWatermark 更新 watermark(UPSERT 到 id=1)。 +// $1::date 让 PG 把入参 truncate 到 UTC 日期,与 last_aggregated_date 列的 DATE 类型一致。 +func (r *channelMonitorRepository) UpdateAggregationWatermark(ctx context.Context, date time.Time) error { + const q = ` + INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at) + VALUES (1, $1::date, NOW()) + ON CONFLICT (id) DO UPDATE SET + last_aggregated_date = EXCLUDED.last_aggregated_date, + updated_at = NOW() + ` + if _, err 
:= r.db.ExecContext(ctx, q, date); err != nil { + return fmt.Errorf("update aggregation watermark: %w", err) + } + return nil +} + +// ---------- helpers ---------- + +func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor { + if row == nil { + return nil + } + extras := row.ExtraModels + if extras == nil { + extras = []string{} + } + headers := row.ExtraHeaders + if headers == nil { + headers = map[string]string{} + } + out := &service.ChannelMonitor{ + ID: row.ID, + Name: row.Name, + Provider: string(row.Provider), + Endpoint: row.Endpoint, + APIKey: row.APIKeyEncrypted, // 仍为密文,service 层负责解密 + PrimaryModel: row.PrimaryModel, + ExtraModels: extras, + GroupName: row.GroupName, + Enabled: row.Enabled, + IntervalSeconds: row.IntervalSeconds, + LastCheckedAt: row.LastCheckedAt, + CreatedBy: row.CreatedBy, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + ExtraHeaders: headers, + BodyOverrideMode: row.BodyOverrideMode, + BodyOverride: row.BodyOverride, + } + if row.TemplateID != nil { + id := *row.TemplateID + out.TemplateID = &id + } + return out +} + +// emptyHeadersIfNilRepo 与 service.emptyHeadersIfNil 功能一致, +// repo 独立一份避免 import 循环。 +func emptyHeadersIfNilRepo(h map[string]string) map[string]string { + if h == nil { + return map[string]string{} + } + return h +} + +// defaultBodyModeRepo 空串归一为 off(同上不循环)。 +func defaultBodyModeRepo(mode string) string { + if mode == "" { + return "off" + } + return mode +} + +func emptySliceIfNil(in []string) []string { + if in == nil { + return []string{} + } + return in +} diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go new file mode 100644 index 0000000000000000000000000000000000000000..845d186b17abd509b371681c3b012ab499f554d9 --- /dev/null +++ b/backend/internal/repository/channel_monitor_template_repo.go @@ -0,0 +1,195 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + + dbent 
"github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/channelmonitor" + "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// channelMonitorRequestTemplateRepository 实现 service.ChannelMonitorRequestTemplateRepository。 +// 与 channelMonitorRepository 分开一个文件,职责清晰。 +type channelMonitorRequestTemplateRepository struct { + client *dbent.Client + db *sql.DB +} + +// NewChannelMonitorRequestTemplateRepository 创建模板仓储实例。 +func NewChannelMonitorRequestTemplateRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRequestTemplateRepository { + return &channelMonitorRequestTemplateRepository{client: client, db: db} +} + +// ---------- CRUD ---------- + +func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error { + client := clientFromContext(ctx, r.client) + builder := client.ChannelMonitorRequestTemplate.Create(). + SetName(t.Name). + SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)). + SetDescription(t.Description). + SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)). + SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode)) + if t.BodyOverride != nil { + builder = builder.SetBodyOverride(t.BodyOverride) + } + + created, err := builder.Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil) + } + t.ID = created.ID + t.CreatedAt = created.CreatedAt + t.UpdatedAt = created.UpdatedAt + return nil +} + +func (r *channelMonitorRequestTemplateRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitorRequestTemplate, error) { + row, err := r.client.ChannelMonitorRequestTemplate.Query(). + Where(channelmonitorrequesttemplate.IDEQ(id)). 
+ Only(ctx) + if err != nil { + return nil, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil) + } + return entToServiceTemplate(row), nil +} + +func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error { + client := clientFromContext(ctx, r.client) + updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID). + SetName(t.Name). + SetDescription(t.Description). + SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)). + SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode)) + if t.BodyOverride != nil { + updater = updater.SetBodyOverride(t.BodyOverride) + } else { + updater = updater.ClearBodyOverride() + } + updated, err := updater.Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil) + } + t.UpdatedAt = updated.UpdatedAt + return nil +} + +func (r *channelMonitorRequestTemplateRepository) Delete(ctx context.Context, id int64) error { + client := clientFromContext(ctx, r.client) + if err := client.ChannelMonitorRequestTemplate.DeleteOneID(id).Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil) + } + return nil +} + +func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, params service.ChannelMonitorRequestTemplateListParams) ([]*service.ChannelMonitorRequestTemplate, error) { + q := r.client.ChannelMonitorRequestTemplate.Query() + if params.Provider != "" { + q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider))) + } + rows, err := q. + Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)). 
+ All(ctx) + if err != nil { + return nil, fmt.Errorf("list monitor templates: %w", err) + } + out := make([]*service.ChannelMonitorRequestTemplate, 0, len(rows)) + for _, row := range rows { + out = append(out, entToServiceTemplate(row)) + } + return out, nil +} + +// ApplyToMonitors 把模板当前配置覆盖到 monitorIDs 列表里的关联监控。 +// WHERE 双重过滤:template_id = id AND id IN (monitorIDs),防止用户传了未关联本模板的 id +// 就被覆盖。走 ent UpdateMany 保留 hooks。 +func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) { + if len(monitorIDs) == 0 { + return 0, nil + } + client := clientFromContext(ctx, r.client) + tpl, err := client.ChannelMonitorRequestTemplate.Query(). + Where(channelmonitorrequesttemplate.IDEQ(id)). + Only(ctx) + if err != nil { + return 0, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil) + } + + updater := client.ChannelMonitor.Update(). + Where( + channelmonitor.TemplateIDEQ(id), + channelmonitor.IDIn(monitorIDs...), + ). + SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)). + SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode)) + if tpl.BodyOverride != nil { + updater = updater.SetBodyOverride(tpl.BodyOverride) + } else { + updater = updater.ClearBodyOverride() + } + + affected, err := updater.Save(ctx) + if err != nil { + return 0, fmt.Errorf("apply template to monitors: %w", err) + } + return int64(affected), nil +} + +// CountAssociatedMonitors 统计关联监控数(UI 展示「N 个配置」用)。 +func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) { + count, err := r.client.ChannelMonitor.Query(). + Where(channelmonitor.TemplateIDEQ(id)). 
+ Count(ctx) + if err != nil { + return 0, fmt.Errorf("count monitors for template %d: %w", id, err) + } + return int64(count), nil +} + +// ListAssociatedMonitors 列出模板关联的所有监控简略字段。 +// ORDER BY name 稳定输出方便前端展示。 +func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx context.Context, id int64) ([]*service.AssociatedMonitorBrief, error) { + rows, err := r.client.ChannelMonitor.Query(). + Where(channelmonitor.TemplateIDEQ(id)). + Order(dbent.Asc(channelmonitor.FieldName)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("list associated monitors for template %d: %w", id, err) + } + out := make([]*service.AssociatedMonitorBrief, 0, len(rows)) + for _, row := range rows { + out = append(out, &service.AssociatedMonitorBrief{ + ID: row.ID, + Name: row.Name, + Provider: string(row.Provider), + Enabled: row.Enabled, + }) + } + return out, nil +} + +// ---------- helpers ---------- + +func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate { + if row == nil { + return nil + } + headers := row.ExtraHeaders + if headers == nil { + headers = map[string]string{} + } + return &service.ChannelMonitorRequestTemplate{ + ID: row.ID, + Name: row.Name, + Provider: string(row.Provider), + Description: row.Description, + ExtraHeaders: headers, + BodyOverrideMode: row.BodyOverrideMode, + BodyOverride: row.BodyOverride, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c17e3365d3ce70e0b0b1482330bc9083a3590eae..5e16475a3a14a69b5ff5e99f3215bffd5a696865 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). 
- SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetRpmLimit(groupIn.RPMLimit) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). - SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetRpmLimit(groupIn.RPMLimit) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 9cf3b3920fb3393844d5d0fe6798df5c6e59f402..6dbb9fbd7c01b3708e9e3dcfd1e49d7f52c8e017 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -51,28 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond const nonTransactionalMigrationSuffix = "_notx.sql" +const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql" +const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique" type migrationChecksumCompatibilityRule struct { fileChecksum string acceptedDBChecksum map[string]struct{} + acceptedChecksums map[string]struct{} } // migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 -// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行, +// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。 var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ - "054_drop_legacy_cache_columns.sql": { - 
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", - acceptedDBChecksum: map[string]struct{}{ - "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, - }, - }, - "061_add_usage_log_request_type.sql": { - fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", - acceptedDBChecksum: map[string]struct{}{ - "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {}, - "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {}, - }, - }, + "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), + "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), + "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), + "110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"), + "112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"), + "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"), + "116_auth_identity_legacy_external_safety_reports.sql": 
newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"), + "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"), + "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), + "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), } // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 @@ -199,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { } if nonTx { + if err := prepareNonTransactionalMigration(ctx, db, name); err != nil { + return fmt.Errorf("prepare migration %s: %w", name, err) + } + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 statements := splitSQLStatements(content) @@ -248,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return nil } +func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error { + switch name { + case 
paymentOrdersOutTradeNoUniqueMigration: + return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db) + default: + return nil + } +} + +func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error { + duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db) + if err != nil { + return fmt.Errorf("precheck duplicate out_trade_no: %w", err) + } + if len(duplicates) > 0 { + return fmt.Errorf( + "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s", + paymentOrdersOutTradeNoUniqueMigration, + strings.Join(duplicates, ", "), + ) + } + + invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex) + if err != nil { + return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + if !invalid { + return nil + } + + if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil { + return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err) + } + return nil +} + +func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) { + rows, err := db.QueryContext(ctx, ` + SELECT out_trade_no, COUNT(*) AS duplicate_count + FROM payment_orders + WHERE out_trade_no <> '' + GROUP BY out_trade_no + HAVING COUNT(*) > 1 + ORDER BY duplicate_count DESC, out_trade_no + LIMIT 5 + `) + if err != nil { + return nil, err + } + defer func() { + _ = rows.Close() + }() + + duplicates := make([]string, 0, 5) + for rows.Next() { + var outTradeNo string + var duplicateCount int + if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil { + return nil, err + } + duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount)) + } + if err := rows.Err(); err != nil { + return nil, err + } + return duplicates, nil +} + +func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) { + var invalid bool + err := 
db.QueryRowContext(ctx, ` + SELECT EXISTS ( + SELECT 1 + FROM pg_class idx + JOIN pg_namespace ns ON ns.oid = idx.relnamespace + JOIN pg_index i ON i.indexrelid = idx.oid + WHERE ns.nspname = 'public' + AND idx.relname = $1 + AND NOT i.indisvalid + ) + `, indexName).Scan(&invalid) + return invalid, err +} + func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error { hasLegacy, err := tableExists(ctx, db, "schema_migrations") if err != nil { @@ -322,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func checksumSet(values ...string) map[string]struct{} { + out := make(map[string]struct{}, len(values)) + for _, value := range values { + out[value] = struct{}{} + } + return out +} + +func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule { + return migrationChecksumCompatibilityRule{ + fileChecksum: fileChecksum, + acceptedDBChecksum: checksumSet(acceptedDBChecksums...), + acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...), + } +} + func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { rule, ok := migrationChecksumCompatibilityRules[name] if !ok { return false } - if rule.fileChecksum != fileChecksum { + _, dbOK := rule.acceptedChecksums[dbChecksum] + if !dbOK { return false } - _, ok = rule.acceptedDBChecksum[dbChecksum] - return ok + _, fileOK := rule.acceptedChecksums[fileChecksum] + return fileOK } func validateMigrationExecutionMode(name, content string) (bool, error) { diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go index 6c3ad725fa541a4e79fb4360ee8f01baceb30e9e..1fcb3be1e7082969d459a2b5dd70c7234522d10c 100644 --- a/backend/internal/repository/migrations_runner_checksum_test.go +++ 
b/backend/internal/repository/migrations_runner_checksum_test.go @@ -51,4 +51,114 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { ) require.False(t, ok) }) + + t.Run("109历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", + ) + require.True(t, ok) + }) + + t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", + ) + require.True(t, ok) + }) + + t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "109_auth_identity_compat_backfill.sql", + "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", + "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee", + ) + require.True(t, ok) + }) + + t.Run("110历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "110_pending_auth_and_provider_default_grants.sql", + "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925", + "32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", + ) + require.True(t, ok) + }) + + t.Run("112历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "112_add_payment_order_provider_key_snapshot.sql", + "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e", + "b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", + ) + require.True(t, ok) + }) + + t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "115_auth_identity_legacy_external_backfill.sql", + "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f", + 
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", + ) + require.True(t, ok) + }) + + t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "116_auth_identity_legacy_external_safety_reports.sql", + "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877", + "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", + ) + require.True(t, ok) + }) + + t.Run("119历史checksum可兼容占位文件", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "119_enforce_payment_orders_out_trade_no_unique.sql", + "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34", + "0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", + ) + require.True(t, ok) + }) + + t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) { + for _, dbChecksum := range []string{ + "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb", + "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", + } { + ok := isMigrationChecksumCompatible( + "118_wechat_dual_mode_and_auth_source_defaults.sql", + dbChecksum, + "b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", + ) + require.True(t, ok) + } + }) + + t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) { + for _, dbChecksum := range []string{ + "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", + "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", + "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a", + } { + ok := isMigrationChecksumCompatible( + "120_enforce_payment_orders_out_trade_no_unique_notx.sql", + dbChecksum, + "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", + ) + require.True(t, ok) + } + }) + + t.Run("119未知checksum不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "119_enforce_payment_orders_out_trade_no_unique.sql", + "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34", + 
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + ) + require.False(t, ok) + }) } diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go index 9f8a94c6ee061d1dc2f21d09beb2c5a87b3d1d3a..5d67665ed456253fb91a6b0cecf6e57da990544e 100644 --- a/backend/internal/repository/migrations_runner_extra_test.go +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) } +func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { + for _, name := range []string{ + "109_auth_identity_compat_backfill.sql", + "110_pending_auth_and_provider_default_grants.sql", + "112_add_payment_order_provider_key_snapshot.sql", + "115_auth_identity_legacy_external_backfill.sql", + "116_auth_identity_legacy_external_safety_reports.sql", + "118_wechat_dual_mode_and_auth_source_defaults.sql", + "120_enforce_payment_orders_out_trade_no_unique_notx.sql", + "123_fix_legacy_auth_source_grant_on_signup_defaults.sql", + } { + rule, ok := migrationChecksumCompatibilityRules[name] + require.Truef(t, ok, "missing compatibility rule for %s", name) + require.NotEmpty(t, rule.fileChecksum) + require.NotEmpty(t, rule.acceptedDBChecksum) + } +} + func TestEnsureAtlasBaselineAligned(t *testing.T) { t.Run("skip_when_no_legacy_table", func(t *testing.T) { db, mock, err := sqlmock.New() diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go index db1183cdbd95dbe60ae57f335c97b10d5215791d..b7cb396c470826c6718c892d77462be159f2c197 100644 --- a/backend/internal/repository/migrations_runner_notx_test.go +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS 
idx_t_b ON t(b); require.NoError(t, mock.ExpectationsWereMet()) } +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). + WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "duplicate out_trade_no") + require.Contains(t, err.Error(), "dup-out-trade-no") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders"). 
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"})) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("paymentorder_out_trade_no_unique"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{ + Data: []byte(` +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index dd3019bbd5932e3cde9b2abe9679521c10015aa7..eeee5c23f407bcf379dc2daea9bb83c50e680aed 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ 
b/backend/internal/repository/migrations_schema_integration_test.go @@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) { + tx := testTx(t) + + requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false) + requireColumn(t, tx, "users", "signup_source", "character varying", 20, false) + requireColumnDefaultContains(t, tx, "users", "signup_source", "email") + requireConstraintDefinitionContains( + t, + tx, + "users", + "users_signup_source_check", + "signup_source", + "'email'", + "'linuxdo'", + "'wechat'", + "'oidc'", + ) + + requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE") + requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE") + requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL") + requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE") + requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL") + + requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no") + requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE") + requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique") +} + func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { t.Helper() @@ -106,6 +135,118 @@ SELECT EXISTS ( require.True(t, exists, "expected index %s on %s", index, table) } +func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE 
schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, err, "query pg_indexes for %s.%s", table, index) + require.False(t, exists, "expected index %s on %s to be absent", index, table) +} + +func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) { + t.Helper() + + var ( + unique bool + def string + ) + + err := tx.QueryRowContext(context.Background(), ` +SELECT + i.indisunique, + pg_get_indexdef(i.indexrelid) +FROM pg_class idx +JOIN pg_index i ON i.indexrelid = idx.oid +JOIN pg_class tbl ON tbl.oid = i.indrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE ns.nspname = 'public' + AND tbl.relname = $1 + AND idx.relname = $2 +`, table, index).Scan(&unique, &def) + require.NoError(t, err, "query index definition for %s.%s", table, index) + require.True(t, unique, "expected index %s on %s to be unique", index, table) + + for _, fragment := range fragments { + require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment) + } +} + +func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) { + t.Helper() + + var actual string + err := tx.QueryRowContext(context.Background(), ` +SELECT CASE c.confdeltype + WHEN 'a' THEN 'NO ACTION' + WHEN 'r' THEN 'RESTRICT' + WHEN 'c' THEN 'CASCADE' + WHEN 'n' THEN 'SET NULL' + WHEN 'd' THEN 'SET DEFAULT' +END +FROM pg_constraint c +JOIN pg_class tbl ON tbl.oid = c.conrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid +JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey) +WHERE ns.nspname = 'public' + AND c.contype = 'f' + AND tbl.relname = $1 + AND attr.attname = $2 + AND ref_tbl.relname = $3 +LIMIT 1 +`, table, column, refTable).Scan(&actual) + require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, 
refTable) + require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable) +} + +func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) { + t.Helper() + + var def string + err := tx.QueryRowContext(context.Background(), ` +SELECT pg_get_constraintdef(c.oid) +FROM pg_constraint c +JOIN pg_class tbl ON tbl.oid = c.conrelid +JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE ns.nspname = 'public' + AND tbl.relname = $1 + AND c.conname = $2 +`, table, constraint).Scan(&def) + require.NoError(t, err, "query constraint definition for %s.%s", table, constraint) + + for _, fragment := range fragments { + require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment) + } +} + +func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) { + t.Helper() + + var columnDefault sql.NullString + err := tx.QueryRowContext(context.Background(), ` +SELECT column_default +FROM information_schema.columns +WHERE table_schema = 'public' + AND table_name = $1 + AND column_name = $2 +`, table, column).Scan(&columnDefault) + require.NoError(t, err, "query column_default for %s.%s", table, column) + require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column) + + for _, fragment := range fragments { + require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment) + } +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/openai_403_counter_cache.go b/backend/internal/repository/openai_403_counter_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..a68d2518de37952ea65ffd476bccc057ef36988f --- /dev/null +++ b/backend/internal/repository/openai_403_counter_cache.go @@ 
-0,0 +1,51 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const openAI403CounterPrefix = "openai_403_count:account:" + +var openAI403CounterIncrScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + local count = redis.call('INCR', key) + if count == 1 then + redis.call('EXPIRE', key, ttl) + end + + return count +`) + +type openAI403CounterCache struct { + rdb *redis.Client +} + +func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache { + return &openAI403CounterCache{rdb: rdb} +} + +func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + + ttlSeconds := windowMinutes * 60 + if ttlSeconds < 60 { + ttlSeconds = 60 + } + + result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64() + if err != nil { + return 0, fmt.Errorf("increment openai 403 count: %w", err) + } + return result, nil +} + +func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index dca0b612fb1728cb06fc1a0c53dedf28504efabd..acb270a3b33f5b491c2c7d247d25d506fb158665 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -2,6 +2,7 @@ package repository import ( "context" + "errors" "net/http" "net/url" "strings" @@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie Post(s.tokenURL) if err != nil { + if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) { + return nil, newOpenAINoProxyHintError(err) + } return 
nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err) } @@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre Post(s.tokenURL) if err != nil { + if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) { + return nil, newOpenAINoProxyHintError(err) + } return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err) } @@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) { Timeout: 120 * time.Second, }) } + +func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool { + if strings.TrimSpace(proxyURL) != "" || err == nil { + return false + } + if ctx != nil && ctx.Err() != nil { + return false + } + return !errors.Is(err, context.Canceled) +} + +func newOpenAINoProxyHintError(cause error) error { + return infraerrors.New( + http.StatusBadGateway, + "OPENAI_OAUTH_PROXY_REQUIRED", + "OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. 
Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.", + ).WithCause(cause) +} diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index c1901d71aaea3d98e2b21ca3e7091eea80e69af1..b43e2b52fc76641e301448013921b9df64112e47 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -8,6 +8,7 @@ import ( "net/url" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { require.ErrorContains(s.T(), err, "request failed") } +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + s.srv.Close() + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") + + require.Error(s.T(), err) + require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err)) + require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured") +} + func (s *OpenAIOAuthServiceSuite) TestContextCancel() { started := make(chan struct{}) block := make(chan struct{}) diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go index 2b6edad3776fb460c52e8476c56b6275ee827d13..62f48b58f4c6e8f9173a683547097bd4b1b14507 100644 --- a/backend/internal/repository/usage_billing_repo.go +++ b/backend/internal/repository/usage_billing_repo.go @@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI if err != nil { return nil, err } - defer func() { _ = rows.Close() }() var state 
service.AccountQuotaState if rows.Next() { @@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI &state.DailyUsed, &state.DailyLimit, &state.WeeklyUsed, &state.WeeklyLimit, ); err != nil { + _ = rows.Close() return nil, err } } else { if err := rows.Err(); err != nil { + _ = rows.Close() return nil, err } + _ = rows.Close() return nil, service.ErrAccountNotFound } if err := rows.Err(); err != nil { + _ = rows.Close() return nil, err } - if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit { + // 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上 + // 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回 + // "unexpected Parse response" 错误。 + if err := rows.Close(); err != nil { + return nil, err + } + // 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照, + // 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号, + // 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。 + // 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount), + // 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。 + crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit + crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit + crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit + if crossedTotal || crossedDaily || crossedWeekly { if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) return nil, err diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go index eda34cc908c951bcdc5af5f5264b294b2d68d311..e8d4d32707fbd3393c1f966303a2117cf43165aa 100644 --- 
a/backend/internal/repository/usage_billing_repo_integration_test.go +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { require.InDelta(t, 3.5, quotaUsed, 0.000001) } +func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + newFixture := func(t *testing.T, extra map[string]any) (int64, int64) { + t.Helper() + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()), + PasswordHash: "hash", + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-outbox-" + uuid.NewString(), + Name: "billing-outbox", + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-outbox-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + Extra: extra, + }) + return apiKey.ID, account.ID + } + + outboxCountFor := func(t *testing.T, accountID int64) int { + t.Helper() + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, + "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2", + service.SchedulerOutboxEventAccountChanged, accountID, + ).Scan(&count)) + return count + } + + t.Run("daily_first_crossing_enqueues", func(t *testing.T) { + apiKeyID, accountID := newFixture(t, map[string]any{ + "quota_daily_limit": 10.0, + }) + // 第一次低于日限额:不应入队 outbox + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 4, + }) + require.NoError(t, err) + require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue") + + // 第二次跨越日限额:应入队一次 outbox + _, err 
= repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 8, + }) + require.NoError(t, err) + require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once") + + // 再次递增(已超):不应重复入队 + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 2, + }) + require.NoError(t, err) + require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue") + }) + + t.Run("weekly_first_crossing_enqueues", func(t *testing.T) { + apiKeyID, accountID := newFixture(t, map[string]any{ + "quota_weekly_limit": 10.0, + }) + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 15, // 单次即跨越 + }) + require.NoError(t, err) + require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once") + }) +} + func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { ctx := context.Background() repo := newDashboardAggregationRepositoryWithSQL(integrationDB) diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eca5313f6da095c7de7d1afd4dff71eabc977a94..74d25cb0b6f98357eacf178ab6b9b8aa0a57b7f3 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -13,14 +13,14 @@ type userGroupRateRepository struct { sql sqlExecutor } -// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储 func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { return 
&userGroupRateRepository{sql: sqlDB} } -// GetByUserID 获取用户的所有专属分组倍率 +// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { - query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL` rows, err := r.sql.QueryContext(ctx, query, userID) if err != nil { return nil, err @@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } -// GetByUserIDs 批量获取多个用户的专属分组倍率。 -// 返回结构:map[userID]map[groupID]rate +// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { result := make(map[int64]map[int64]float64, len(userIDs)) if len(userIDs) == 0 { @@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in rows, err := r.sql.QueryContext(ctx, ` SELECT user_id, group_id, rate_multiplier FROM user_group_rate_multipliers - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL `, pq.Array(uniqueIDs)) if err != nil { return nil, err @@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in return result, nil } -// GetByGroupID 获取指定分组下所有用户的专属倍率 +// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { query := ` - SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier + SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override FROM user_group_rate_multipliers ugr JOIN users u ON u.id = ugr.user_id 
AND u.deleted_at IS NULL WHERE ugr.group_id = $1 @@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 var result []service.UserGroupRateEntry for rows.Next() { var entry service.UserGroupRateEntry - if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { + var rate sql.NullFloat64 + var rpm sql.NullInt32 + if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil { return nil, err } + if rate.Valid { + v := rate.Float64 + entry.RateMultiplier = &v + } + if rpm.Valid { + v := int(rpm.Int32) + entry.RPMOverride = &v + } result = append(result, entry) } if err := rows.Err(); err != nil { @@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 return result, nil } -// GetByUserAndGroup 获取用户在特定分组的专属倍率 +// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` - var rate float64 + var rate sql.NullFloat64 err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) if err == sql.ErrNoRows { return nil, nil @@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, if err != nil { return nil, err } - return &rate, nil + if !rate.Valid { + return nil, nil + } + v := rate.Float64 + return &v, nil +} + +// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) +func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) { + query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rpm sql.NullInt32 + err := scanSingleRow(ctx, 
r.sql, query, []any{userID, groupID}, &rpm) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if !rpm.Valid { + return nil, nil + } + v := int(rpm.Int32) + return &v, nil } -// SyncUserGroupRates 同步用户的分组专属倍率 +// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。 +// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。 +// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。 +// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。 func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { if len(rates) == 0 { - // 如果传入空 map,删除该用户的所有专属倍率 - _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 + `, userID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID) return err } - // 分离需要删除和需要 upsert 的记录 - var toDelete []int64 + var clearGroupIDs []int64 upsertGroupIDs := make([]int64, 0, len(rates)) upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { - toDelete = append(toDelete, groupID) + clearGroupIDs = append(clearGroupIDs, groupID) } else { upsertGroupIDs = append(upsertGroupIDs, groupID) upsertRates = append(upsertRates, *rate) } } - // 删除指定的记录 - if len(toDelete) > 0 { + if len(clearGroupIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 AND group_id = ANY($2) + `, userID, pq.Array(clearGroupIDs)); err != nil { + return err + } if _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, - userID, pq.Array(toDelete)); err 
!= nil { + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID, pq.Array(clearGroupIDs)); err != nil { return err } } - // Upsert 记录 - now := time.Now() if len(upsertGroupIDs) > 0 { + now := time.Now() _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) SELECT @@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID return nil } -// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) +// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。 +// 语义: +// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。 +// - 出现的用户行:upsert rate_multiplier。 func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { - if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { + keepUserIDs := make([]int64, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + } + + // 未在 entries 列表中的行:清空 rate_multiplier。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID); err != nil { return err } + if len(entries) == 0 { return nil } + userIDs := 
make([]int64, len(entries)) rates := make([]float64, len(entries)) for i, e := range entries { @@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, return err } -// DeleteByGroupID 删除指定分组的所有用户专属倍率 +// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。 +// 语义: +// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。 +// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。 +func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error { + keepUserIDs := make([]int64, 0, len(entries)) + var clearUserIDs []int64 + upsertUserIDs := make([]int64, 0, len(entries)) + upsertValues := make([]int32, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + if e.RPMOverride == nil { + clearUserIDs = append(clearUserIDs, e.UserID) + } else { + upsertUserIDs = append(upsertUserIDs, e.UserID) + upsertValues = append(upsertValues, int32(*e.RPMOverride)) + } + } + + // 未在 entries 列表中的行:清空 rpm_override。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 显式 clear 的行。 + if len(clearUserIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id = ANY($2) + `, groupID, pq.Array(clearUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL 
AND rpm_override IS NULL + `, groupID); err != nil { + return err + } + + if len(upsertUserIDs) > 0 { + now := time.Now() + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at) + SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz + FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override) + ON CONFLICT (user_id, group_id) + DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at + `, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues)) + if err != nil { + return err + } + } + + return nil +} + +// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。 +func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID) + return err +} + +// DeleteByGroupID 删除指定分组的所有用户专属条目 func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) return err } -// DeleteByUserID 删除指定用户的所有专属倍率 +// DeleteByUserID 删除指定用户的所有专属条目 func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) return err diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go new file mode 100644 index 0000000000000000000000000000000000000000..b2b03746d0d82ccc292bcdf81e7926cd6d8d5e7b --- /dev/null +++ 
b/backend/internal/repository/user_profile_identity_repo.go @@ -0,0 +1,880 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "hash/fnv" + "reflect" + "sort" + "strings" + "sync" + "time" + "unsafe" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +var ( + ErrAuthIdentityOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_OWNERSHIP_CONFLICT", + "auth identity already belongs to another user", + ) + ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", + "auth identity channel already belongs to another user", + ) + ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest( + "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH", + "auth identity channel provider must match canonical identity", + ) +) + +type ProviderGrantReason string + +const ( + ProviderGrantReasonSignup ProviderGrantReason = "signup" + ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind" +) + +type AuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type AuthIdentityChannelKey struct { + ProviderType string + ProviderKey string + Channel string + ChannelAppID string + ChannelSubject string +} + +type CreateAuthIdentityInput struct { + UserID int64 + Canonical AuthIdentityKey + Channel *AuthIdentityChannelKey + Issuer *string + VerifiedAt *time.Time + Metadata map[string]any + ChannelMetadata map[string]any +} + +type BindAuthIdentityInput = CreateAuthIdentityInput + +type CreateAuthIdentityResult struct { + Identity *dbent.AuthIdentity + Channel 
*dbent.AuthIdentityChannel +} + +func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey { + if r == nil || r.Identity == nil { + return AuthIdentityKey{} + } + return AuthIdentityKey{ + ProviderType: r.Identity.ProviderType, + ProviderKey: r.Identity.ProviderKey, + ProviderSubject: r.Identity.ProviderSubject, + } +} + +func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey { + if r == nil || r.Channel == nil { + return nil + } + return &AuthIdentityChannelKey{ + ProviderType: r.Channel.ProviderType, + ProviderKey: r.Channel.ProviderKey, + Channel: r.Channel.Channel, + ChannelAppID: r.Channel.ChannelAppID, + ChannelSubject: r.Channel.ChannelSubject, + } +} + +type UserAuthIdentityLookup struct { + User *dbent.User + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +type ProviderGrantRecordInput struct { + UserID int64 + ProviderType string + GrantReason ProviderGrantReason +} + +type IdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type sqlQueryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +var repositoryScopedKeyLocks = newScopedKeyLockRegistry() + +type scopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*scopedKeyLockEntry +} + +type scopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newScopedKeyLockRegistry() *scopedKeyLockRegistry { + return &scopedKeyLockRegistry{ + locks: make(map[string]*scopedKeyLockEntry), + } +} + +func (r *scopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeLockKeys(keys...) 
+ if len(normalized) == 0 { + return func() {} + } + + entries := make([]*scopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &scopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func advisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) { + release := repositoryScopedKeyLocks.lock(keys...) + normalized := normalizeLockKeys(keys...) 
+ if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key)) + if err != nil { + release() + return nil, err + } + _ = rows.Close() + } + return release, nil +} + +func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if dbent.TxFromContext(ctx) != nil { + return fn(ctx) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + return err + } + return tx.Commit() +} + +func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + + client := clientFromContext(ctx, r.client) + + create := client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt) + + identity, err := create.Save(ctx) + if err != nil { + return nil, err + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). 
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(ctx) + if err != nil { + return nil, err + } + } + + return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil +} + +func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) { + identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)), + ). + WithUser(). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: identity.Edges.User, + Identity: identity, + }, nil +} + +func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) { + channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: channel.Edges.Identity.Edges.User, + Identity: channel.Edges.Identity, + Channel: channel, + }, nil +} + +func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where(authidentity.UserIDEQ(userID)). 
+ All(ctx) + if err != nil { + return nil, err + } + + records := make([]service.UserAuthIdentityRecord, 0, len(identities)) + for _, identity := range identities { + if identity == nil { + continue + } + records = append(records, service.UserAuthIdentityRecord{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: copyMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + }) + } + + return records, nil +} + +func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" || provider == "email" { + return service.ErrIdentityProviderInvalid + } + + return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + identityIDs, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ(provider), + ). + IDs(txCtx) + if err != nil { + return err + } + if len(identityIDs) == 0 { + return nil + } + + if _, err := client.IdentityAdoptionDecision.Update(). + Where(identityadoptiondecision.IdentityIDIn(identityIDs...)). + ClearIdentityID(). + Save(txCtx); err != nil { + return err + } + if _, err := client.AuthIdentityChannel.Delete(). + Where(authidentitychannel.IdentityIDIn(identityIDs...)). + Exec(txCtx); err != nil { + return err + } + _, err = client.AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ(provider), + ). 
+ Exec(txCtx) + return err + }) +} + +func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + + var result *CreateAuthIdentityResult + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + canonical := input.Canonical + + identityRecords, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), + authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...), + authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), + ). + All(txCtx) + if err != nil { + return err + } + identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID) + if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) { + return ErrAuthIdentityOwnershipConflict + } + if identity == nil { + identity, err = client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt). 
+ Save(txCtx) + if err != nil { + return err + } + } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey) + update := client.AuthIdentity.UpdateOneID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } + if input.Metadata != nil { + update = update.SetMetadata(copyMetadata(input.Metadata)) + } + if input.Issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*input.Issuer)) + } + if input.VerifiedAt != nil { + update = update.SetVerifiedAt(*input.VerifiedAt) + } + identity, err = update.Save(txCtx) + if err != nil { + return err + } + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channelRecords, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), + authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...), + authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), + ). + WithIdentity(). + All(txCtx) + if err != nil { + return err + } + channel = selectOwnedCompatibleChannel(channelRecords, input.UserID) + if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) { + return ErrAuthIdentityChannelOwnershipConflict + } + if channel == nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). 
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(txCtx) + if err != nil { + return err + } + } else { + targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey) + update := client.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID) + if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) { + update = update.SetProviderKey(targetProviderKey) + } + if input.ChannelMetadata != nil { + update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) + } + channel, err = update.Save(txCtx) + if err != nil { + return err + } + } + } + + result = &CreateAuthIdentityResult{Identity: identity, Channel: channel} + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func compatibleIdentityProviderKeys(providerType, providerKey string) []string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return []string{providerKey} + } + if providerType != "wechat" { + return []string{providerKey} + } + keys := []string{providerKey} + if !strings.EqualFold(providerKey, "wechat-main") { + keys = append(keys, "wechat-main") + } + if !strings.EqualFold(providerKey, "wechat") { + keys = append(keys, "wechat") + } + return keys +} + +func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return 
"wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func compatibleIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + +func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity + for _, record := range records { + if record.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool { + for _, record := range records { + if record.UserID != userID { + return true + } + } + return false +} + +func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel + for _, record := range records { + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return true + } + } + return false +} + +func (r *userRepository) RecordProviderGrant(ctx context.Context, input 
ProviderGrantRecordInput) (bool, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return false, fmt.Errorf("sql executor is not configured") + } + + result, err := exec.ExecContext(ctx, ` +INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + input.UserID, + strings.TrimSpace(input.ProviderType), + string(input.GrantReason), + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + var result *dbent.IdentityAdoptionDecision + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + releaseLocks, err := lockRepositoryScopedKeys( + txCtx, + client, + txAwareSQLExecutor(txCtx, r.sql, r.client), + identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)..., + ) + if err != nil { + return err + } + defer releaseLocks() + + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(txCtx); err != nil { + return err + } + } + + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). 
+ SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { + create = create.SetIdentityID(*input.IdentityID) + } + + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). + ID(txCtx) + if err != nil { + return err + } + + result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID) + return err + }) + if err != nil { + return nil, err + } + return result, nil +} + +func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID)) + } + return keys +} + +func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { + return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)). + Only(ctx) +} + +func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastLoginAt(loginAt). + Save(ctx) + return err +} + +func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastActiveAt(activeAt). 
+ Save(ctx) + return err +} + +func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + rows, err := exec.QueryContext(ctx, ` +SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 +FROM user_avatars +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + _, err = exec.ExecContext(ctx, ` +INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) +ON CONFLICT (user_id) DO UPDATE SET + storage_provider = EXCLUDED.storage_provider, + storage_key = EXCLUDED.storage_key, + url = EXCLUDED.url, + content_type = EXCLUDED.content_type, + byte_size = EXCLUDED.byte_size, + sha256 = EXCLUDED.sha256, + updated_at = NOW()`, + userID, + strings.TrimSpace(input.StorageProvider), + strings.TrimSpace(input.StorageKey), + strings.TrimSpace(input.URL), + strings.TrimSpace(input.ContentType), + input.ByteSize, + strings.TrimSpace(input.SHA256), + ) + if err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: strings.TrimSpace(input.StorageProvider), + StorageKey: strings.TrimSpace(input.StorageKey), + URL: strings.TrimSpace(input.URL), + ContentType: 
strings.TrimSpace(input.ContentType), + ByteSize: input.ByteSize, + SHA256: strings.TrimSpace(input.SHA256), + }, nil +} + +func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return err + } + _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID) + return err +} + +func copyMetadata(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error { + if channel == nil { + return nil + } + + canonicalProviderType := strings.TrimSpace(canonical.ProviderType) + canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey) + channelProviderType := strings.TrimSpace(channel.ProviderType) + channelProviderKey := strings.TrimSpace(channel.ProviderKey) + + if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey { + return ErrAuthIdentityChannelProviderMismatch + } + + return nil +} + +func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { + return exec + } + } + if fallback != nil { + return fallback + } + return sqlExecutorFromEntClient(client) +} + +func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + return exec, nil +} + +func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor { + if client == nil { + return nil + } + + clientValue := reflect.ValueOf(client).Elem() + configValue := clientValue.FieldByName("config") + 
driverValue := configValue.FieldByName("driver") + if !driverValue.IsValid() { + return nil + } + + driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface() + exec, ok := driver.(sqlQueryExecutor) + if !ok { + return nil + } + return exec +} diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d4f9e8b34bdd49d953b4e10bb6e0e39e116bdb7a --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -0,0 +1,578 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" +) + +type UserProfileIdentityRepoSuite struct { + suite.Suite + ctx context.Context + client *dbent.Client + repo *userRepository +} + +func TestUserProfileIdentityRepoSuite(t *testing.T) { + suite.Run(t, new(UserProfileIdentityRepoSuite)) +} + +func (s *UserProfileIdentityRepoSuite) SetupTest() { + s.ctx = context.Background() + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + _, err := integrationDB.ExecContext(s.ctx, ` +TRUNCATE TABLE + identity_adoption_decisions, + auth_identity_channels, + auth_identities, + pending_auth_sessions, + user_provider_default_grants, + user_avatars +RESTART IDENTITY`) + s.Require().NoError(err) +} + +func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User { + s.T().Helper() + + user, err := s.client.User.Create(). + SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())). + SetPasswordHash("test-password-hash"). + SetRole("user"). 
+ SetStatus("active"). + Save(s.ctx) + s.Require().NoError(err) + return user +} + +func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession { + s.T().Helper() + + session, err := s.client.PendingAuthSession.Create(). + SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())). + SetIntent("bind_current_user"). + SetProviderType(key.ProviderType). + SetProviderKey(key.ProviderKey). + SetProviderSubject(key.ProviderSubject). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(s.ctx) + s.Require().NoError(err) + return session +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() { + user := s.mustCreateUser("canonical-channel") + + verifiedAt := time.Now().UTC().Truncate(time.Second) + created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + Channel: "mp", + ChannelAppID: "wx-app", + ChannelSubject: "openid-123", + }, + Issuer: stringPtr("https://issuer.example"), + VerifiedAt: &verifiedAt, + Metadata: map[string]any{"unionid": "union-123"}, + ChannelMetadata: map[string]any{"openid": "openid-123"}, + }) + s.Require().NoError(err) + s.Require().NotNil(created.Identity) + s.Require().NotNil(created.Channel) + + canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, canonical.User.ID) + s.Require().Equal(created.Identity.ID, canonical.Identity.ID) + s.Require().Equal("union-123", canonical.Identity.ProviderSubject) + + channel, err := s.repo.GetUserByChannelIdentity(s.ctx, 
*created.ChannelRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, channel.User.ID) + s.Require().Equal(created.Identity.ID, channel.Identity.ID) + s.Require().Equal(created.Channel.ID, channel.Channel.ID) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() { + owner := s.mustCreateUser("owner") + other := s.mustCreateUser("other") + + first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "first"}, + ChannelMetadata: map[string]any{"scope": "read"}, + }) + s.Require().NoError(err) + + second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "second"}, + ChannelMetadata: map[string]any{"scope": "write"}, + }) + s.Require().NoError(err) + s.Require().Equal(first.Identity.ID, second.Identity.ID) + s.Require().Equal(first.Channel.ID, second.Channel.ID) + s.Require().Equal("second", second.Identity.Metadata["username"]) + s.Require().Equal("write", second.Channel.Metadata["scope"]) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, 
ErrAuthIdentityOwnershipConflict) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-2", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() { + user := s.mustCreateUser("wechat-legacy-alias") + + legacyIdentity, err := s.client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy-alias"}). + Save(s.ctx) + s.Require().NoError(err) + + legacyChannel, err := s.client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("oa"). + SetChannelAppID("wx-app-legacy"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy-alias"}). 
+ Save(s.ctx) + s.Require().NoError(err) + + bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + Channel: "oa", + ChannelAppID: "wx-app-legacy", + ChannelSubject: "openid-legacy-123", + }, + Metadata: map[string]any{"source": "canonical-bind"}, + ChannelMetadata: map[string]any{"scene": "canonical-bind"}, + }) + s.Require().NoError(err) + s.Require().NotNil(bound) + s.Require().NotNil(bound.Identity) + s.Require().NotNil(bound.Channel) + s.Require().Equal(legacyIdentity.ID, bound.Identity.ID) + s.Require().Equal(legacyChannel.ID, bound.Channel.ID) + s.Require().Equal("wechat-main", bound.Identity.ProviderKey) + s.Require().Equal("wechat-main", bound.Channel.ProviderKey) + s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"]) + s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"]) + + identityCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, identityCount) + + channelCount, err := s.client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("oa"), + authidentitychannel.ChannelAppIDEQ("wx-app-legacy"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). 
+ Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, channelCount) +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-create") + + _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-create-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "app-mismatch", + ChannelSubject: "openid-create-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-bind") + + _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-bind-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-legacy", + Channel: "oa", + ChannelAppID: "wx-app-bind-mismatch", + ChannelSubject: "openid-bind-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { + user := s.mustCreateUser("tx-rollback") + expectedErr := errors.New("rollback") + + err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }, + }) + s.Require().NoError(err) + + inserted, err := s.repo.RecordProviderGrant(txCtx, 
ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "oidc", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + return expectedErr + }) + s.Require().ErrorIs(err, expectedErr) + + _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }) + s.Require().True(dbent.IsNotFound(err)) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`, + user.ID, + "oidc", + string(ProviderGrantReasonFirstBind), + ).Scan(&count)) + s.Require().Zero(count) +} + +func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() { + user := s.mustCreateUser("grant") + + inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().False(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonSignup, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2`, + user.ID, + "wechat", + ).Scan(&count)) + s.Require().Equal(2, count) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() { + user := s.mustCreateUser("adoption") + identity, err 
:= s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + s.Require().NoError(err) + + session := s.mustCreatePendingAuthSession(identity.IdentityRef()) + + first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().True(first.AdoptDisplayName) + s.Require().False(first.AdoptAvatar) + s.Require().Nil(first.IdentityID) + + second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().Equal(first.ID, second.ID) + s.Require().NotNil(second.IdentityID) + s.Require().Equal(identity.Identity.ID, *second.IdentityID) + s.Require().True(second.AdoptAvatar) + + loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID) + s.Require().NoError(err) + s.Require().Equal(second.ID, loaded.ID) + s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() { + user := s.mustCreateUser("adoption-reassign") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption-reassign", + }, + }) + s.Require().NoError(err) + + firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: firstSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: 
true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().NotNil(firstDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID) + + secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: secondSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().NotNil(secondDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID) + + reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID) + s.Require().NoError(err) + s.Require().Nil(reloadedFirst.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() { + user := s.mustCreateUser("avatar-only-update") + + model, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(model) + + err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + if err != nil { + return err + } + return s.repo.Update(txCtx, model) + }) + s.Require().NoError(err) + + avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(avatar) + s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL) +} + +func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { + user := s.mustCreateUser("avatar") + + inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: "data:image/png;base64,QUJD", + ContentType: "image/png", + ByteSize: 3, + SHA256: "902fbdd2b1df0c4f70b4a5d23525e932", + }) + 
s.Require().NoError(err) + s.Require().Equal("inline", inlineAvatar.StorageProvider) + s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL) + + loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("image/png", loadedAvatar.ContentType) + s.Require().Equal(3, loadedAvatar.ByteSize) + + _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + s.Require().NoError(err) + + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("remote_url", loadedAvatar.StorageProvider) + s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL) + s.Require().Zero(loadedAvatar.ByteSize) + + s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID)) + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Nil(loadedAvatar) +} + +func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() { + user := s.mustCreateUser("activity") + loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + activeAt := loginAt.Add(5 * time.Minute) + + s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt)) + s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt)) + + var storedLoginAt sqlNullTime + var storedActiveAt sqlNullTime + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT last_login_at, last_active_at +FROM users +WHERE id = $1`, + user.ID, + ).Scan(&storedLoginAt, &storedActiveAt)) + s.Require().True(storedLoginAt.Valid) + s.Require().True(storedActiveAt.Valid) + s.Require().True(storedLoginAt.Time.Equal(loginAt)) + s.Require().True(storedActiveAt.Time.Equal(activeAt)) +} + +type sqlNullTime struct { + Time time.Time + Valid bool +} + +func (t *sqlNullTime) 
Scan(value any) error { + switch v := value.(type) { + case time.Time: + t.Time = v + t.Valid = true + return nil + case nil: + t.Time = time.Time{} + t.Valid = false + return nil + default: + return fmt.Errorf("unsupported scan type %T", value) + } +} + +func stringPtr(v string) *string { + return &v +} diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..689f32f9b1e711ea2345815d9773ac0818aac2e2 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go @@ -0,0 +1,212 @@ +package repository + +import ( + "context" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "wechat-legacy@example.com", + Username: "wechat-legacy", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + legacyChannel, err := client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("oa"). + SetChannelAppID("wx-app-legacy"). + SetChannelSubject("openid-legacy-123"). 
+ SetMetadata(map[string]any{"scene": "legacy-alias"}). + Save(ctx) + require.NoError(t, err) + + bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + Channel: "oa", + ChannelAppID: "wx-app-legacy", + ChannelSubject: "openid-legacy-123", + }, + Metadata: map[string]any{"source": "canonical-bind"}, + ChannelMetadata: map[string]any{"scene": "canonical-bind"}, + }) + require.NoError(t, err) + require.NotNil(t, bound) + require.NotNil(t, bound.Identity) + require.NotNil(t, bound.Channel) + require.Equal(t, legacyIdentity.ID, bound.Identity.ID) + require.Equal(t, legacyChannel.ID, bound.Channel.ID) + require.Equal(t, "wechat-main", bound.Identity.ProviderKey) + require.Equal(t, "wechat-main", bound.Channel.ProviderKey) + + reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey) + require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"]) + + reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", reloadedChannel.ProviderKey) + require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). 
+ Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("oa"), + authidentitychannel.ChannelAppIDEQ("wx-app-legacy"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, channelCount) +} + +func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + user := &service.User{ + Email: "repo-adoption@example.com", + Username: "repo-adoption", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, user)) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-repo-adoption"). + SetIntent("bind_current_user"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-repo-adoption"). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}). + SetLocalFlowState(map[string]any{"step": "pending"}). 
+ Save(ctx) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) + require.True(t, loaded.AdoptDisplayName) + require.True(t, loaded.AdoptAvatar) +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 913e1c4000595b454f22e3f52d93d06f24aca4e1..d1f10cbdcce57f164512e95bbb4a3d1667ea0452 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -11,12 +11,17 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" entsql "entgo.io/ent/dialect/sql" ) @@ -47,12 +52,33 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { - // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 - txClient = r.client + // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } + } + + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := 
ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil { + return err } created, err := txClient.User.Create(). @@ -64,12 +90,19 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - Save(ctx) + SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). + SetNillableLastLoginAt(userIn.LastLoginAt). + SetNillableLastActiveAt(userIn.LastActiveAt). + SetRpmLimit(userIn.RPMLimit). + Save(txCtx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil { + return err + } + if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { return err } @@ -101,10 +134,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, } func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { - m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). 
+ All(ctx) if err != nil { - return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound } + if len(matches) > 1 { + return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email)) + } + m := matches[0] out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) @@ -129,14 +172,41 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error } var txClient *dbent.Client + txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() txClient = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) } else { - // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。 - txClient = r.client + // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } + } + + releaseEmailLock, err := lockRepositoryScopedKeys( + txCtx, + txClient, + txAwareSQLExecutor(txCtx, r.sql, r.client), + normalizedEmailUniquenessLockKey(userIn.Email), + ) + if err != nil { + return err + } + defer releaseEmailLock() + + if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil { + return err } + existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + oldEmail := existing.Email + updateOp := txClient.User.UpdateOneID(userIn.ID). SetEmail(userIn.Email). SetUsername(userIn.Username). @@ -150,16 +220,29 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). 
- SetTotalRecharged(userIn.TotalRecharged) + SetTotalRecharged(userIn.TotalRecharged). + SetRpmLimit(userIn.RPMLimit) + if userIn.SignupSource != "" { + updateOp = updateOp.SetSignupSource(userIn.SignupSource) + } + if userIn.LastLoginAt != nil { + updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt) + } + if userIn.LastActiveAt != nil { + updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt) + } if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } - updated, err := updateOp.Save(ctx) + updated, err := updateOp.Save(txCtx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) } - if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { + if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil { + return err + } + if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { return err } @@ -173,14 +256,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error return nil } +func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error { + client = clientFromContext(ctx, client) + if client == nil || userID <= 0 { + return nil + } + + subject := normalizeEmailAuthIdentitySubject(email) + if subject == "" { + return nil + } + + if err := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(subject). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": source}). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). 
+ Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return err + } + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(subject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity.UserID != userID { + return ErrAuthIdentityOwnershipConflict + } + return nil +} + +func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error { + newSubject := normalizeEmailAuthIdentitySubject(newEmail) + if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil { + return err + } + + oldSubject := normalizeEmailAuthIdentitySubject(oldEmail) + if oldSubject == "" || oldSubject == newSubject { + return nil + } + + _, err := clientFromContext(ctx, client).AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(oldSubject), + ). 
+ Exec(ctx) + return err +} + +func normalizeEmailAuthIdentitySubject(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" { + return "" + } + if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) { + return "" + } + return normalized +} + func (r *userRepository) Delete(ctx context.Context, id int64) error { - affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx) + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + txClient = existingTx.Client() + } else { + txClient = r.client + } + } + + identityIDs, err := txClient.AuthIdentity.Query(). + Where(authidentity.UserIDEQ(id)). + IDs(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if len(identityIDs) > 0 { + if _, err := txClient.IdentityAdoptionDecision.Update(). + Where(identityadoptiondecision.IdentityIDIn(identityIDs...)). + ClearIdentityID(). + Save(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if _, err := txClient.AuthIdentityChannel.Delete(). + Where(authidentitychannel.IdentityIDIn(identityIDs...)). + Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if _, err := txClient.AuthIdentity.Delete(). + Where(authidentity.UserIDEQ(id)). 
+ Exec(ctx); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + } + + affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } if affected == 0 { return service.ErrUserNotFound } + + if tx != nil { + if err := tx.Commit(); err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + } return nil } @@ -298,8 +513,13 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + if sortBy == "last_used_at" { + return userLastUsedAtOrder(sortOrder) + } + var field string defaultField := true + nullsLastField := false switch sortBy { case "email": field = dbuser.FieldEmail @@ -322,6 +542,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) case "created_at": field = dbuser.FieldCreatedAt defaultField = false + case "last_active_at": + field = dbuser.FieldLastActiveAt + defaultField = false + nullsLastField = true default: field = dbuser.FieldID } @@ -330,14 +554,92 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(), + dbent.Asc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} } if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(), + dbent.Desc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Desc(field), 
dbent.Desc(dbuser.FieldID)} } +func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + result := make(map[int64]*time.Time, len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + if r.sql == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + const query = ` + SELECT user_id, MAX(created_at) AS last_used_at + FROM usage_logs + WHERE user_id = ANY($1) + GROUP BY user_id + ` + + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var ( + userID int64 + lastUsedAt time.Time + ) + if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil { + return nil, scanErr + } + ts := lastUsedAt.UTC() + result[userID] = &ts + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID}) + if err != nil { + return nil, err + } + return latestByUserID[userID], nil +} + +func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) { + orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) { + return func(s *entsql.Selector) { + subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID)) + s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls)) + s.OrderBy(tieOrder(s.C(dbuser.FieldID))) + } + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){ + orderExpr("ASC", "FIRST", entsql.Asc), + } + } + return []func(*entsql.Selector){ + orderExpr("DESC", "LAST", entsql.Desc), + } +} + // filterUsersByAttributes returns user IDs that match ALL the given attribute filters func (r *userRepository) filterUsersByAttributes(ctx context.Context, 
attrs map[int64]string) ([]int64, error) { if len(attrs) == 0 { @@ -436,17 +738,68 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { - return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) + return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) +} + +func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error { + client = clientFromContext(ctx, client) + if client == nil { + return nil + } + + matches, err := client.User.Query(). + Where(userEmailLookupPredicate(email)). + All(ctx) + if err != nil { + return err + } + for _, match := range matches { + if match.ID != userID { + return service.ErrEmailExists + } + } + return nil +} + +func userEmailLookupPredicate(email string) predicate.User { + normalized := normalizeEmailLookupValue(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.P(func(b *entsql.Builder) { + b.WriteString("LOWER(TRIM("). + Ident(s.C(dbuser.FieldEmail)). + WriteString(")) = "). + Arg(normalized) + })) + }) +} + +func normalizeEmailLookupValue(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} + +func normalizedEmailUniquenessLockKey(email string) string { + normalized := normalizeEmailLookupValue(email) + if normalized == "" { + return "" + } + return "users:normalized-email:" + normalized } func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { client := clientFromContext(ctx, r.client) - return client.UserAllowedGroup.Create(). + err := client.UserAllowedGroup.Create(). SetUserID(userID). SetGroupID(groupID). OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). DoNothing(). 
Exec(ctx) + if isSQLNoRowsError(err) { + return nil + } + return err } func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { @@ -546,6 +899,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID). DoNothing(). Exec(ctx); err != nil { + if isSQLNoRowsError(err) { + return nil + } return err } } @@ -558,10 +914,24 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { return } dst.ID = src.ID + dst.SignupSource = src.SignupSource + dst.LastLoginAt = src.LastLoginAt + dst.LastActiveAt = src.LastActiveAt dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } +func userSignupSourceOrDefault(signupSource string) string { + switch strings.TrimSpace(strings.ToLower(signupSource)) { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return strings.TrimSpace(strings.ToLower(signupSource)) + default: + return "email" + } +} + // marshalExtraEmails serializes notify email entries to JSON for storage. 
func marshalExtraEmails(entries []service.NotifyEmailEntry) string { return service.MarshalNotifyEmails(entries) diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fddd82c5981a220cf541083fa82fc088dd81e55c --- /dev/null +++ b/backend/internal/repository/user_repo_email_identity_integration_test.go @@ -0,0 +1,86 @@ +//go:build integration + +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() { + user := &service.User{ + Email: "repo-create@example.com", + PasswordHash: "test-password-hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 2, + } + + s.Require().NoError(s.repo.Create(s.ctx, user)) + + identity, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("repo-create@example.com"), + ). + Only(s.ctx) + s.Require().NoError(err) + s.Require().Equal(user.ID, identity.UserID) +} + +func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() { + user := &service.User{ + Email: "linuxdo-legacy-user@linuxdo-connect.invalid", + PasswordHash: "test-password-hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 2, + } + + s.Require().NoError(s.repo.Create(s.ctx, user)) + + count, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + ). 
+ Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(count) +} + +func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() { + user := s.mustCreateUser(&service.User{ + Email: "before-update@example.com", + }) + + user.Email = "after-update@example.com" + s.Require().NoError(s.repo.Update(s.ctx, user)) + + newIdentity, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("after-update@example.com"), + ). + Only(s.ctx) + s.Require().NoError(err) + s.Require().Equal(user.ID, newIdentity.UserID) + + oldCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("before-update@example.com"), + ). + Count(context.Background()) + s.Require().NoError(err) + s.Require().Zero(oldCount) +} diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7da3db9b15e94e14e4ab206eced5d8ab240cf100 --- /dev/null +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -0,0 +1,227 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name())) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + db.SetMaxOpenConns(10) + + 
_, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return newUserRepositoryWithSQL(client, db), client +} + +func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + got, err := repo.GetByEmail(ctx, "legacy@example.com") + require.NoError(t, err) + require.Equal(t, " Legacy@Example.com ", got.Email) +} + +func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ") + require.NoError(t, err) + require.True(t, exists) +} + +func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Existing@Example.com ", + Username: "existing-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + err = repo.Create(ctx, &service.User{ + Email: "existing@example.com", + Username: "duplicate-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.ErrorIs(t, err, service.ErrEmailExists) +} + +func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) { + repo, _ := newUserEntRepo(t) + 
ctx := context.Background() + + first := &service.User{ + Email: " Existing@Example.com ", + Username: "existing-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, first)) + + second := &service.User{ + Email: "second@example.com", + Username: "second-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, second)) + + second.Email = " existing@example.com " + err := repo.Update(ctx, second) + require.ErrorIs(t, err, service.ErrEmailExists) +} + +func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + _, err := client.User.Create(). + SetEmail("Conflict@Example.com"). + SetUsername("conflict-user-1"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.User.Create(). + SetEmail(" conflict@example.com "). + SetUsername("conflict-user-2"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + _, err = repo.GetByEmail(ctx, "conflict@example.com") + require.Error(t, err) + require.ErrorContains(t, err, "normalized email lookup matched multiple users") +} + +func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) { + repo, client := newUserEntRepo(t) + ctx := context.Background() + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.User.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type createResult struct { + err error + } + + results := make(chan createResult, 2) + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: " Race@Example.com ", + Username: "race-user-1", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + <-firstCreateStarted + + go func() { + results <- createResult{err: repo.Create(ctx, &service.User{ + Email: "race@example.com", + Username: "race-user-2", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + })} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + second := <-results + + errors := []error{first.err, second.err} + successes := 0 + conflicts := 0 + for _, err := range errors { + switch err { + case nil: + successes++ + case service.ErrEmailExists: + conflicts++ + default: + t.Fatalf("unexpected create error: %v", err) + } + } + require.Equal(t, 1, successes) + require.Equal(t, 1, conflicts) + + count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx) + require.NoError(t, err) + 
require.Equal(t, 1, count) +} diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index f5d0f9ff1893024e7187c2283e62a16c0ad3ad3c..13a605a2f5467dc6685bc1f5b05941d763416d53 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -8,6 +8,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() { s.repo = newUserRepositoryWithSQL(s.client, integrationDB) // 清理测试数据,确保每个测试从干净状态开始 + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels") + _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups") _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users") @@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() { s.Require().Equal(user.ID, got.ID) } +func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() { + user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "}) + + got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ") + s.Require().NoError(err, "GetByEmail normalized lookup") + s.Require().Equal(user.ID, got.ID) +} + func (s *UserRepoSuite) TestGetByEmail_NotFound() { _, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com") s.Require().Error(err, "expected error for non-existent email") } +func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() { + s.mustCreateUser(&service.User{Email: " Legacy@Example.com "}) + + exists, err := s.repo.ExistsByEmail(s.ctx, " 
LEGACY@example.com ") + s.Require().NoError(err, "ExistsByEmail normalized lookup") + s.Require().True(exists) +} + func (s *UserRepoSuite) TestUpdate() { user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"}) @@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() { s.Require().Equal("updated", updated.Username) } +func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() { + user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"}) + + identityCount, err := s.client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("update-existing-identity@test.com"), + ). + Count(s.ctx) + s.Require().NoError(err) + s.Require().Equal(1, identityCount) + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + got.Username = "updated" + s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows") + + updated, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal("updated", updated.Username) +} + func (s *UserRepoSuite) TestDelete() { user := s.mustCreateUser(&service.User{Email: "delete@test.com"}) @@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() { s.Require().Error(err, "expected error after delete") } +func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() { + user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"}) + + identity, err := s.client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("delete-oauth-subject"). + Save(s.ctx) + s.Require().NoError(err) + + _, err = s.client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("open"). 
+ SetChannelAppID("app-id"). + SetChannelSubject("openid-123"). + Save(s.ctx) + s.Require().NoError(err) + + err = s.repo.Delete(s.ctx, user.ID) + s.Require().NoError(err) + + identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(identityCount) + + channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx) + s.Require().NoError(err) + s.Require().Zero(channelCount) +} + // --- List / ListWithFilters --- func (s *UserRepoSuite) TestList() { diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go index ab84b0e93bb969bfaf1c863cdb30067cc15a8942..3a15bc1024b66033002876732a1047a10ab50595 100644 --- a/backend/internal/repository/user_repo_sort_integration_test.go +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -4,11 +4,30 @@ package repository import ( "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" ) +func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) { + s.T().Helper() + + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID}) + + _, err := integrationDB.ExecContext( + s.ctx, + `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at) + VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`, + userID, + apiKey.ID, + account.ID, + createdAt.UTC(), + ) + s.Require().NoError(err) +} + func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() { s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"}) s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"}) @@ -36,4 +55,110 @@ 
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { s.Require().Equal(first.ID, users[1].ID) } +func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() { + lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond) + + created := s.mustCreateUser(&service.User{ + Email: "identity-meta@example.com", + SignupSource: "linuxdo", + LastLoginAt: &lastLoginAt, + LastActiveAt: &lastActiveAt, + }) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("linuxdo", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() { + created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"}) + lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond) + + created.SignupSource = "oidc" + created.LastLoginAt = &lastLoginAt + created.LastActiveAt = &lastActiveAt + + s.Require().NoError(s.repo.Update(s.ctx, created)) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("oidc", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() { + earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond) + later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-active@example.com"}) + 
s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later}) + s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_active_at", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("earlier-active@example.com", users[0].Email) + s.Require().Equal("later-active@example.com", users[1].Email) + s.Require().Equal("nil-active@example.com", users[2].Email) +} + +func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() { + older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second) + newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second) + + userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"}) + userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"}) + s.mustInsertUsageLog(userWithUsage.ID, older) + s.mustInsertUsageLog(userWithUsage.ID, newer) + + got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID}) + s.Require().NoError(err) + s.Require().Contains(got, userWithUsage.ID) + s.Require().NotContains(got, userWithoutUsage.ID) + s.Require().NotNil(got[userWithUsage.ID]) + s.Require().True(got[userWithUsage.ID].Equal(newer)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() { + lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second) + lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second) + lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second) + + nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"}) + wrongSource := s.mustCreateUser(&service.User{ + Email: "active-not-usage@example.com", + LastActiveAt: &lastActiveVeryRecent, 
+ }) + rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"}) + + s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder) + s.mustInsertUsageLog(rightSource.ID, lastUsedNewer) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_used_at", + SortOrder: "desc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal(rightSource.ID, users[0].ID) + s.Require().Equal(wrongSource.ID, users[1].ID) + s.Require().Equal(nilUsage.ID, users[2].ID) +} + func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/repository/user_rpm_cache.go b/backend/internal/repository/user_rpm_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..42bf93326e97d67747acac988b6b9a7cfc9b4752 --- /dev/null +++ b/backend/internal/repository/user_rpm_cache.go @@ -0,0 +1,108 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// 用户/分组级 RPM 计数器 Redis 实现。 +// +// 设计说明: +// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute} +// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。 +// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。 +// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。 +// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。 +const ( + userGroupRPMKeyPrefix = "rpm:ug:" + userRPMKeyPrefix = "rpm:u:" + + userRPMKeyTTL = 120 * time.Second +) + +type userRPMCacheImpl struct { + rdb *redis.Client +} + +// NewUserRPMCache 创建用户/分组级 RPM 计数器。 +func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache { + return &userRPMCacheImpl{rdb: rdb} +} + +// minuteTS 获取当前 Redis 服务端分钟时间戳。 +func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("redis TIME: %w", err) + } + return t.Unix() / 60, nil +} + +// atomicIncr 原子 
INCR+EXPIRE。 +func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) { + pipe := c.rdb.TxPipeline() + incr := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, userRPMKeyTTL) + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("user rpm increment: %w", err) + } + return int(incr.Val()), nil +} + +// IncrementUserGroupRPM 递增 (user, group) 分钟计数。 +func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + return c.atomicIncr(ctx, key) +} + +// IncrementUserRPM 递增用户分钟计数。 +func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + return c.atomicIncr(ctx, key) +} + +// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user group rpm get: %w", err) + } + return val, nil +} + +// GetUserRPM 获取用户当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user rpm get: %w", err) + } + return val, nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go 
index d3adb4a0aa7b52daf523627c72f2b4ebdff35523..6d24d312bf450dd0d872e0832e3b1170b1a66b7b 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -89,6 +89,8 @@ var ProviderSet = wire.NewSet( NewErrorPassthroughRepository, NewTLSFingerprintProfileRepository, NewChannelRepository, + NewChannelMonitorRepository, + NewChannelMonitorRequestTemplateRepository, // Cache implementations NewGatewayCache, @@ -96,10 +98,12 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, NewTimeoutCounterCache, + NewOpenAI403CounterCache, NewInternal500CounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, + NewUserRPMCache, NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b686b986faee962e777845a584dc6a553c9128af..e89ef3d9835681a4bf84c06ed3605ead4b86eb88 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -50,10 +50,12 @@ func TestAPIContracts(t *testing.T) { "data": { "id": 1, "email": "alice@example.com", + "email_bound": true, "username": "alice", "role": "user", "balance": 12.5, "concurrency": 5, + "rpm_limit": 0, "status": "active", "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", @@ -63,6 +65,123 @@ func TestAPIContracts(t *testing.T) { "balance_notify_threshold": null, "balance_notify_extra_emails": null, "total_recharged": 0, + "linuxdo_bound": false, + "oidc_bound": false, + "wechat_bound": false, + "identities": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note_key": "profile.authBindings.notes.emailManagedFromProfile", + "note": "Primary account email is managed from the profile form." 
+ }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "identity_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note_key": "profile.authBindings.notes.emailManagedFromProfile", + "note": "Primary account email is managed from the profile form." 
+ }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "auth_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note_key": "profile.authBindings.notes.emailManagedFromProfile", + "note": "Primary account email is managed from the profile form." 
+ }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, "run_mode": "standard" } }`, @@ -215,6 +334,7 @@ func TestAPIContracts(t *testing.T) { "fallback_group_id_on_invalid_request": null, "require_oauth_only": false, "require_privacy_set": false, + "rpm_limit": 0, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -479,7 +599,7 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectUsePKCE: "true", service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectClockSkewSeconds: "120", @@ -500,10 +620,15 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyTableDefaultPageSize: "20", service.SettingKeyTablePageSizeOptions: "[10,20,50,100]", - service.SettingKeyOpsMonitoringEnabled: "false", - service.SettingKeyOpsRealtimeMonitoringEnabled: "true", - service.SettingKeyOpsQueryModeDefault: "auto", - service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingKeyOpsMonitoringEnabled: "false", + 
service.SettingKeyOpsRealtimeMonitoringEnabled: "true", + service.SettingKeyOpsQueryModeDefault: "auto", + service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay, + service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat, + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + "openai_advanced_scheduler_enabled": "true", }) }, method: http.MethodGet, @@ -549,7 +674,7 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_redirect_url": "", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_token_auth_method": "client_secret_post", - "oidc_connect_use_pkce": false, + "oidc_connect_use_pkce": true, "oidc_connect_validate_id_token": true, "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_clock_skew_seconds": 120, @@ -567,8 +692,30 @@ func TestAPIContracts(t *testing.T) { "api_base_url": "https://api.example.com", "contact_info": "support", "doc_url": "https://docs.example.com", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + 
"auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", @@ -592,6 +739,11 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "easypay_alipay", + "payment_visible_method_wxpay_source": "official_wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -618,7 +770,222 @@ func TestAPIContracts(t *testing.T) { "account_quota_notify_enabled": false, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", - "account_quota_notify_emails": [] + "account_quota_notify_emails": [], + "channel_monitor_enabled": true, + "channel_monitor_default_interval_seconds": 60, + "available_channels_enabled": false, + "wechat_connect_enabled": false, + "wechat_connect_app_id": "", + "wechat_connect_app_secret_configured": false, + "wechat_connect_mode": "open", + "wechat_connect_open_enabled": false, + "wechat_connect_open_app_id": "", + "wechat_connect_open_app_secret_configured": false, + "wechat_connect_mp_enabled": false, + "wechat_connect_mp_app_id": "", + "wechat_connect_mp_app_secret_configured": false, + "wechat_connect_mobile_enabled": false, + "wechat_connect_mobile_app_id": "", + "wechat_connect_mobile_app_secret_configured": false, + "wechat_connect_redirect_url": "", + "wechat_connect_frontend_redirect_url": "/auth/wechat/callback", + "wechat_connect_scopes": "snsapi_login" + } + }`, + }, + { + 
name: "GET /api/v1/admin/settings falls back to config oauth defaults", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.cfg.OIDC = config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "ConfigOIDC", + ClientID: "oidc-config-client", + ClientSecret: "oidc-config-secret", + IssuerURL: "https://issuer.example.com", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256,ES256,PS256", + ClockSkewSeconds: 120, + } + deps.cfg.WeChat = config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + Mode: "open", + Scopes: "snsapi_login", + FrontendRedirectURL: "/auth/wechat/callback", + } + deps.settingRepo.SetAll(map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", + }) + }, + method: http.MethodGet, + path: "/api/v1/admin/settings", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "registration_enabled": true, + "email_verify_enabled": false, + "registration_email_suffix_whitelist": [], + "promo_code_enabled": true, + "password_reset_enabled": false, + "frontend_url": "", + "invitation_code_enabled": false, + "totp_enabled": false, + "totp_encryption_key_configured": false, + "smtp_host": "", + "smtp_port": 587, + "smtp_username": "", + "smtp_password_configured": false, + "smtp_from_email": "", + "smtp_from_name": "", + "smtp_use_tls": false, + "turnstile_enabled": false, + "turnstile_site_key": "", + "turnstile_secret_key_configured": false, + "linuxdo_connect_enabled": false, + "linuxdo_connect_client_id": "", + "linuxdo_connect_client_secret_configured": false, + "linuxdo_connect_redirect_url": "", 
+ "oidc_connect_enabled": true, + "oidc_connect_provider_name": "ConfigOIDC", + "oidc_connect_client_id": "oidc-config-client", + "oidc_connect_client_secret_configured": true, + "oidc_connect_issuer_url": "https://issuer.example.com", + "oidc_connect_discovery_url": "", + "oidc_connect_authorize_url": "", + "oidc_connect_token_url": "", + "oidc_connect_userinfo_url": "", + "oidc_connect_jwks_url": "", + "oidc_connect_scopes": "openid email profile", + "oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback", + "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", + "oidc_connect_token_auth_method": "client_secret_post", + "oidc_connect_use_pkce": true, + "oidc_connect_validate_id_token": true, + "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", + "oidc_connect_clock_skew_seconds": 120, + "oidc_connect_require_email_verified": false, + "oidc_connect_userinfo_email_path": "", + "oidc_connect_userinfo_id_path": "", + "oidc_connect_userinfo_username_path": "", + "site_name": "Sub2API", + "site_logo": "", + "site_subtitle": "Subscription to API Conversion Platform", + "api_base_url": "", + "contact_info": "", + "doc_url": "", + "home_content": "", + "hide_ccs_import_button": false, + "purchase_subscription_enabled": false, + "purchase_subscription_url": "", + "table_default_page_size": 20, + "table_page_size_options": [10, 20, 50], + "custom_menu_items": [], + "custom_endpoints": [], + "default_concurrency": 0, + "default_balance": 0, + "default_user_rpm_limit": 0, + "default_subscriptions": [], + "enable_model_fallback": false, + "fallback_model_anthropic": "claude-3-5-sonnet-20241022", + "fallback_model_openai": "gpt-4o", + "fallback_model_gemini": "gemini-2.5-pro", + "fallback_model_antigravity": "gemini-2.5-pro", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "ops_monitoring_enabled": false, + "ops_realtime_monitoring_enabled": true, + "ops_query_mode_default": "auto", + "ops_metrics_interval_seconds": 
60, + "min_claude_code_version": "", + "max_claude_code_version": "", + "allow_ungrouped_key_scheduling": false, + "backend_mode_enabled": false, + "enable_fingerprint_unification": true, + "enable_metadata_passthrough": false, + "enable_cch_signing": false, + "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "", + "payment_visible_method_wxpay_source": "", + "payment_visible_method_alipay_enabled": false, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": false, + "payment_enabled": false, + "payment_min_amount": 0, + "payment_max_amount": 0, + "payment_daily_limit": 0, + "payment_order_timeout_minutes": 0, + "payment_max_pending_orders": 0, + "payment_enabled_types": null, + "payment_balance_disabled": false, + "payment_balance_recharge_multiplier": 0, + "payment_recharge_fee_rate": 0, + "payment_load_balance_strategy": "", + "payment_product_name_prefix": "", + "payment_product_name_suffix": "", + "payment_help_image_url": "", + "payment_help_text": "", + "payment_cancel_rate_limit_enabled": false, + "payment_cancel_rate_limit_max": 0, + "payment_cancel_rate_limit_window": 0, + "payment_cancel_rate_limit_unit": "", + "payment_cancel_rate_limit_window_mode": "", + "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, + "balance_low_notify_threshold": 0, + "balance_low_notify_recharge_url": "", + "account_quota_notify_emails": [], + "channel_monitor_enabled": true, + "channel_monitor_default_interval_seconds": 60, + "available_channels_enabled": false, + "wechat_connect_enabled": true, + "wechat_connect_app_id": "wx-open-config", + "wechat_connect_app_secret_configured": true, + "wechat_connect_mode": "open", + "wechat_connect_open_enabled": true, + "wechat_connect_open_app_id": "wx-open-config", + "wechat_connect_open_app_secret_configured": true, + "wechat_connect_mp_enabled": false, + "wechat_connect_mp_app_id": "wx-open-config", + 
"wechat_connect_mp_app_secret_configured": true, + "wechat_connect_mobile_enabled": false, + "wechat_connect_mobile_app_id": "wx-open-config", + "wechat_connect_mobile_app_secret_configured": true, + "wechat_connect_redirect_url": "", + "wechat_connect_frontend_redirect_url": "/auth/wechat/callback", + "wechat_connect_scopes": "snsapi_login", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + "auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false } }`, }, @@ -665,6 +1032,7 @@ func TestAPIContracts(t *testing.T) { type contractDeps struct { now time.Time router http.Handler + cfg *config.Config apiKeyRepo *stubApiKeyRepo groupRepo *stubGroupRepo userSubRepo *stubUserSubscriptionRepo @@ -726,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := 
service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) @@ -785,6 +1153,7 @@ func newContractDeps(t *testing.T) *contractDeps { return &contractDeps{ now: now, router: r, + cfg: cfg, apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userSubRepo: userSubRepo, @@ -858,6 +1227,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } +func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -894,6 +1275,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 return errors.New("not implemented") } +func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + return nil, nil +} + +func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, 
userID int64) (*time.Time, error) { + return nil, nil +} + +func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + return nil +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index ed2578c843f1f9ef38b0a4c6110d0468cadebe6d..06e3355e5ea2b7d1c127135163412af8d8baefce 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { panic("unexpected Delete call") } +func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { panic("unexpected List call") } @@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa panic("unexpected ListWithFilters call") } +func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID 
int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { panic("unexpected UpdateBalance call") } @@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go index 46482af315649d3339071d527a9dedfcb73199a4..ae53037e67e2ba06ac6dae58e6011d76457c5c5a 100644 --- a/backend/internal/server/middleware/backend_mode_guard.go +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun } } +func backendModeAllowsAuthPath(path string) bool { + path = strings.ToLower(strings.TrimSpace(path)) + for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} { + if strings.HasSuffix(path, suffix) { + return true + } + } + + for _, suffix := range []string{ + "/auth/oauth/linuxdo/callback", + "/auth/oauth/wechat/callback", + "/auth/oauth/wechat/payment/callback", + "/auth/oauth/oidc/callback", + "/auth/oauth/linuxdo/complete-registration", + 
"/auth/oauth/wechat/complete-registration", + "/auth/oauth/oidc/complete-registration", + "/auth/oauth/linuxdo/create-account", + "/auth/oauth/wechat/create-account", + "/auth/oauth/oidc/create-account", + "/auth/oauth/linuxdo/bind-login", + "/auth/oauth/wechat/bind-login", + "/auth/oauth/oidc/bind-login", + } { + if strings.HasSuffix(path, suffix) { + return true + } + } + + return strings.Contains(path, "/auth/oauth/pending/") +} + // BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. -// Allows: login, login/2fa, logout, refresh (admin needs these). -// Blocks: register, forgot-password, reset-password, OAuth, etc. +// Allows the minimal auth surface admins still need in backend mode, including +// OAuth callbacks and pending continuations. Handler-level backend mode checks +// still enforce admin-only login and forbid self-service registration. func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { return func(c *gin.Context) { if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { c.Next() return } - path := c.Request.URL.Path - // Allow login, 2FA, logout, refresh, public settings - allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} - for _, suffix := range allowedSuffixes { - if strings.HasSuffix(path, suffix) { - c.Next() - return - } + if backendModeAllowsAuthPath(c.Request.URL.Path) { + c.Next() + return } response.Forbidden(c, "Backend mode is active. 
Registration and self-service auth flows are disabled.") c.Abort() diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go index 8878ebc92d956e5c830f15edf43d8aba19a3f3d2..bd77677b74d330d6fc0f4d0427ccf7f69ff82ddb 100644 --- a/backend/internal/server/middleware/backend_mode_guard_test.go +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) { path: "/api/v1/auth/refresh", wantStatus: http.StatusOK, }, + { + name: "enabled_blocks_linuxdo_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_linuxdo_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_payment_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_payment_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_oidc_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_oidc_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_exchange", + enabled: "true", + path: "/api/v1/auth/oauth/pending/exchange", + wantStatus: http.StatusOK, + }, + { + name: 
"enabled_allows_oauth_pending_send_verify_code", + enabled: "true", + path: "/api/v1/auth/oauth/pending/send-verify-code", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/pending/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/pending/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_legacy_complete_registration", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/complete-registration", + wantStatus: http.StatusOK, + }, { name: "enabled_blocks_register", enabled: "true", diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 4aceb3550258b23efdaad7aced6ebd95b47d32ea..48cb9004bfb918f754bee37f04491017882b1cc8 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "errors" "strings" @@ -11,11 +12,19 @@ import ( // NewJWTAuthMiddleware 创建 JWT 认证中间件 func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware { - return JWTAuthMiddleware(jwtAuth(authService, userService)) + return JWTAuthMiddleware(jwtAuth(authService, userService, userService)) +} + +type jwtUserReader interface { + GetByID(ctx context.Context, id int64) (*service.User, error) +} + +type userActivityToucher interface { + TouchLastActiveForUser(ctx context.Context, user *service.User) } // jwtAuth JWT认证中间件实现 -func jwtAuth(authService *service.AuthService, 
userService *service.UserService) gin.HandlerFunc { +func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc { return func(c *gin.Context) { // 从Authorization header中提取token authHeader := c.GetHeader("Authorization") @@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) Concurrency: user.Concurrency, }) c.Set(string(ContextKeyUserRole), user.Role) + if activityToucher != nil { + activityToucher.TouchLastActiveForUser(c.Request.Context(), user) + } c.Next() } diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index c483a51eafca6b83e5867277ae4c227046494e07..84fd696739dfa61332e4d12c0dd14611f77b2b11 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" @@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e return u, nil } +func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error { + return nil +} + +type recordingActivityToucher struct { + userIDs []int64 +} + +func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) { + if user == nil { + return + } + r.userIDs = append(r.userIDs, user.ID) +} + // newJWTTestEnv 创建 JWT 认证中间件测试环境。 // 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { @@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func 
TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil, nil) + toucher := &recordingActivityToucher{} + + r := gin.New() + r.Use(jwtAuth(authSvc, userSvc, toucher)) + r.GET("/protected", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, []int64{1}, toucher.userIDs) +} + func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { router, _ := newJWTTestEnv(nil) diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 7021ab2e23c8b78f8bb1b0f99e074c77e1cec80d..398c0351dc0814981e6bf511666d17dbfb815e90 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -96,7 +96,8 @@ func isAPIRoutePath(c *gin.Context) bool { return strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/responses") + strings.HasPrefix(path, "/responses") || + strings.HasPrefix(path, "/images") } // enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, diff --git a/backend/internal/server/routes/admin.go 
b/backend/internal/server/routes/admin.go index 9af0fd8ef6ab886f66f4b4fed2ea42099f01db2d..70160f7eecedea3fd3a3f12d0eff8aa3d263d9fe 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -88,6 +88,9 @@ func RegisterAdminRoutes( // 渠道管理 registerChannelRoutes(admin, h) + + // 渠道监控 + registerChannelMonitorRoutes(admin, h) } } @@ -212,6 +215,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) + users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) users.POST("", h.Admin.User.Create) users.PUT("/:id", h.Admin.User.Update) users.DELETE("/:id", h.Admin.User.Delete) @@ -220,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) + users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -243,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) + groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides) + groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) } } @@ -563,3 +570,27 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { channels.DELETE("/:id", h.Admin.Channel.Delete) } } + +func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + monitors := admin.Group("/channel-monitors") + { + 
monitors.GET("", h.Admin.ChannelMonitor.List) + monitors.POST("", h.Admin.ChannelMonitor.Create) + monitors.GET("/:id", h.Admin.ChannelMonitor.Get) + monitors.PUT("/:id", h.Admin.ChannelMonitor.Update) + monitors.DELETE("/:id", h.Admin.ChannelMonitor.Delete) + monitors.POST("/:id/run", h.Admin.ChannelMonitor.Run) + monitors.GET("/:id/history", h.Admin.ChannelMonitor.History) + } + + templates := admin.Group("/channel-monitor-templates") + { + templates.GET("", h.Admin.ChannelMonitorTemplate.List) + templates.POST("", h.Admin.ChannelMonitorTemplate.Create) + templates.GET("/:id", h.Admin.ChannelMonitorTemplate.Get) + templates.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update) + templates.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete) + templates.GET("/:id/monitors", h.Admin.ChannelMonitorTemplate.AssociatedMonitors) + templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply) + } +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c143b030fc88da423867e4bf77404474d9039f78..642a2103e328c4a078d49f83c31669d0dc4d1336 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -63,14 +63,90 @@ func RegisterAuthRoutes( FailureMode: middleware.RateLimitFailClose, }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.LinuxDoOAuthStart(c) + }) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.WeChatOAuthStart(c) + }) + auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) + 
auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart) + auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback) + auth.POST("/oauth/pending/exchange", + rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.ExchangePendingOAuthCompletion, + ) + auth.POST("/oauth/pending/send-verify-code", + rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.SendPendingOAuthVerifyCode, + ) + auth.POST("/oauth/pending/create-account", + rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreatePendingOAuthAccount, + ) + auth.POST("/oauth/pending/bind-login", + rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindPendingOAuthLogin, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/linuxdo/bind-login", + rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindLinuxDoOAuthLogin, + ) + auth.POST("/oauth/linuxdo/create-account", + rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateLinuxDoOAuthAccount, + ) + auth.POST("/oauth/wechat/complete-registration", + rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ + 
FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteWeChatOAuthRegistration, + ) + auth.POST("/oauth/wechat/bind-login", + rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindWeChatOAuthLogin, + ) + auth.POST("/oauth/wechat/create-account", + rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateWeChatOAuthAccount, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) + auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.OIDCOAuthStart(c) + }) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{ @@ -78,6 +154,18 @@ func RegisterAuthRoutes( }), h.Auth.CompleteOIDCOAuthRegistration, ) + auth.POST("/oauth/oidc/bind-login", + rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindOIDCOAuthLogin, + ) + auth.POST("/oauth/oidc/create-account", + rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateOIDCOAuthAccount, + ) } // 公开设置(无需认证) @@ -94,5 +182,6 @@ func RegisterAuthRoutes( authenticated.GET("/auth/me", h.Auth.GetCurrentUser) // 撤销所有会话(需要认证) authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) + authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie) } } diff --git a/backend/internal/server/routes/auth_rate_limit_test.go 
b/backend/internal/server/routes/auth_rate_limit_test.go index 4f411cec570a3595248d7d2b86f440f6b0bdf119..07a66efb4e580cbade54a8befa4cf00b63d06783 100644 --- a/backend/internal/server/routes/auth_rate_limit_test.go +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { "/api/v1/auth/login", "/api/v1/auth/login/2fa", "/api/v1/auth/send-verify-code", + "/api/v1/auth/oauth/pending/send-verify-code", } for _, path := range paths { diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index cbf9829326006df1e0adb8658a93dbaded5348b0..9541cda1abf6a4cda07b4680c0c2df3e257c1280 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -88,6 +88,30 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + gateway.POST("/images/generations", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) + gateway.POST("/images/edits", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -116,6 +140,13 @@ func RegisterGatewayRoutes( r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, 
endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + codexDirect := r.Group("/backend-api/codex") + codexDirect.Use(bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic) + { + codexDirect.POST("/responses", responsesHandler) + codexDirect.POST("/responses/*subpath", responsesHandler) + codexDirect.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) + } // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { if getGroupPlatform(c) == service.PlatformOpenAI { @@ -124,6 +155,30 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) + r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go index 4d65a626f9f96131a84e78143e448105d64b1ae1..19ef568600c2aab76144a6e447d771483feecdfe 
100644 --- a/backend/internal/server/routes/gateway_test.go +++ b/backend/internal/server/routes/gateway_test.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -24,6 +25,11 @@ func newGatewayRoutesTestRouter() *gin.Engine { OpenAIGateway: &handler.OpenAIGatewayHandler{}, }, servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { + groupID := int64(1) + c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{ + GroupID: &groupID, + Group: &service.Group{Platform: service.PlatformOpenAI}, + }) c.Next() }), nil, @@ -39,7 +45,12 @@ func newGatewayRoutesTestRouter() *gin.Engine { func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { router := newGatewayRoutesTestRouter() - for _, path := range []string{"/v1/responses/compact", "/responses/compact"} { + for _, path := range []string{ + "/v1/responses/compact", + "/responses/compact", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", + } { req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -48,3 +59,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path) } } + +func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) { + router := newGatewayRoutesTestRouter() + + for _, path := range []string{ + "/v1/images/generations", + "/v1/images/edits", + "/images/generations", + "/images/edits", + } { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`)) + req.Header.Set("Content-Type", 
"application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path) + } +} diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 23bd58ad8eb20335e6169cc3cc89e4c445d15455..e4828eadf95b8a518c6c71b2b94ee6f4d3508424 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -44,11 +44,13 @@ func RegisterPaymentRoutes( } // --- Public payment endpoints (no auth) --- - // Payment result page needs to verify order status without login - // (user session may have expired during provider redirect). + // Signed resume-token recovery is the preferred public lookup path. + // The legacy anonymous out_trade_no verify endpoint remains available as a + // persisted-state compatibility path for staggered upgrades. public := v1.Group("/payment/public") { public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) + public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken) } // --- Webhook endpoints (no auth) --- diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index d004f8b4391d02c28a40e8d8f0f4ae50f881be4c..babab125340e28f938efef30642682c933a9f8e5 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,10 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode) + user.POST("/account-bindings/email", h.User.BindEmailIdentity) + user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) + user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") @@ -64,6 +68,12 @@ func RegisterUserRoutes( groups.GET("/rates", 
h.APIKey.GetUserGroupRates) } + // 用户可用渠道(非管理员接口) + channels := authenticated.Group("/channels") + { + channels.GET("/available", h.AvailableChannel.List) + } + // 使用记录 usage := authenticated.Group("/usage") { @@ -99,5 +109,12 @@ func RegisterUserRoutes( subscriptions.GET("/progress", h.Subscription.GetProgress) subscriptions.GET("/summary", h.Subscription.GetSummary) } + + // 渠道监控(用户只读) + monitors := authenticated.Group("/channel-monitors") + { + monitors.GET("", h.ChannelMonitor.List) + monitors.GET("/:id/status", h.ChannelMonitor.GetStatus) + } } } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 52db30738a4f644ed0e2025a57e81ddaa3c7ed2b..0fb6e18f526d9c955575d307d06ec5a1bfaee8c8 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool { if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) { return false } + if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() { + return false + } return true } @@ -908,6 +911,32 @@ func (a *Account) GetChatGPTAccountID() string { return a.GetCredential("chatgpt_account_id") } +func (a *Account) GetOpenAIDeviceID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return strings.TrimSpace(a.GetExtraString("openai_device_id")) +} + +func (a *Account) GetOpenAISessionID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return strings.TrimSpace(a.GetExtraString("openai_session_id")) +} + +func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool { + if !a.IsOpenAI() { + return false + } + switch capability { + case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative: + return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey + default: + return true + } +} + func (a *Account) GetChatGPTUserID() string { if !a.IsOpenAIOAuth() { return "" diff --git a/backend/internal/service/account_quota_schedulable_test.go 
b/backend/internal/service/account_quota_schedulable_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2895b34c889e98c2af60cfab0ca619adc0d1d32a --- /dev/null +++ b/backend/internal/service/account_quota_schedulable_test.go @@ -0,0 +1,123 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + account *Account + want bool + }{ + { + name: "apikey daily quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 10.0, + "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + }, + want: false, + }, + { + name: "apikey weekly quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_weekly_limit": 50.0, + "quota_weekly_used": 50.0, + "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339), + }, + }, + want: false, + }, + { + name: "apikey total quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_limit": 100.0, + "quota_used": 100.0, + }, + }, + want: false, + }, + { + name: "apikey quota not exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 5.0, + "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + }, + want: true, + }, + { + name: "apikey expired daily period restores schedulable", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 10.0, + "quota_daily_start": 
now.Add(-25 * time.Hour).Format(time.RFC3339), + }, + }, + want: true, + }, + { + name: "oauth ignores quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 10.0, + "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + }, + want: true, + }, + { + name: "bedrock quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeBedrock, + Extra: map[string]any{ + "quota_limit": 200.0, + "quota_used": 200.0, + }, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.account.IsSchedulable()) + }) + } +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index a5559b7de2312f4b120fc7458217a6a1ce41ae69..e5bc93ca0893d2eeaaf690be55e640f6d94d8cfd 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -52,8 +52,14 @@ type TestEvent struct { const ( defaultGeminiTextTestPrompt = "hi" defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." + defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." ) +// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2). 
+func isOpenAIImageModel(model string) bool { + return strings.HasPrefix(strings.ToLower(model), "gpt-image-") +} + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -170,7 +176,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int // Route to platform-specific test method if account.IsOpenAI() { - return s.testOpenAIAccountConnection(c, account, modelID) + return s.testOpenAIAccountConnection(c, account, modelID, prompt) } if account.IsGemini() { @@ -410,8 +416,9 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co } // testOpenAIAccountConnection tests an OpenAI account's connection -func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() + _ = prompt // Default to openai.DefaultTestModel for OpenAI testing testModelID := modelID @@ -429,6 +436,18 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } } + // Route to image generation test if an image model is selected + if isOpenAIImageModel(testModelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultOpenAIImageTestPrompt + } + if account.Type == "apikey" { + return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt) + } + return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt) + } + // Determine authentication method and API URL var authToken string var apiURL string @@ -1025,7 +1044,198 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) } } -// sendEvent sends a SSE event to the client +// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account. 
+func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { + authToken := account.GetOpenAIApiKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) + } + apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations" + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) + + payload := map[string]any{ + "model": modelID, + "prompt": prompt, + "n": 1, + "response_format": "b64_json", + } + payloadBytes, _ := json.Marshal(payload) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error())) + } + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, 
fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]} + var result struct { + Data []struct { + B64JSON string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` + } `json:"data"` + } + if err := json.Unmarshal(body, &result); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error())) + } + + if len(result.Data) == 0 { + return s.sendErrorAndEnd(c, "No images returned from API") + } + + for _, item := range result.Data { + if item.RevisedPrompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt}) + } + if item.B64JSON != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: "data:image/png;base64," + item.B64JSON, + MimeType: "image/png", + }) + } + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API. 
+func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { + authToken := account.GetOpenAIAccessToken() + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) + s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"}) + + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: strings.TrimSpace(modelID), + Prompt: prompt, + } + applyOpenAIImagesDefaults(parsed) + + responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error())) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Host = "chatgpt.com" + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", "opencode") + if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { + req.Header.Set("User-Agent", customUA) + } else { + req.Header.Set("User-Agent", codexCLIUserAgent) + } + if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() 
+ } + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error())) + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + message := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if message == "" { + message = fmt.Sprintf("Responses API returned %d", resp.StatusCode) + } + return s.sendErrorAndEnd(c, message) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error())) + } + + results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error())) + } + if len(results) == 0 { + return s.sendErrorAndEnd(c, "No images returned from responses API") + } + + for _, item := range results { + if item.RevisedPrompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt}) + } + mimeType := openAIImageOutputMIMEType(item.OutputFormat) + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: "data:" + mimeType + ";base64," + item.Result, + MimeType: mimeType, + }) + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { eventJSON, _ := json.Marshal(event) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go new file mode 100644 index 0000000000000000000000000000000000000000..80a2fc31553425d56f228f9a062ab88d43dc8838 --- /dev/null +++ b/backend/internal/service/account_test_service_openai_image_test.go @@ -0,0 +1,50 @@ 
+package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 53, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat") + require.NoError(t, err) + require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool") + require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") + require.Contains(t, rec.Body.String(), "\"success\":true") +} diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 8260697991fd02995066ffaf648eef2b0ffec30a..82ff0a8b700e4c6a895aff464df46e55289ea431 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ 
b/backend/internal/service/account_test_service_openai_test.go @@ -103,7 +103,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.NoError(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) @@ -134,7 +134,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c6254ff95818788cef884c82253f70a2b..434f1f38e6d4e6e71882a378cc831778c5a36038 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,15 +2,19 @@ package service import ( "context" + "encoding/json" "errors" "fmt" "io" "log/slog" "net/http" + "sort" "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -29,10 +33,12 @@ type AdminService interface { UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, 
period string) (any, error) + GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) + BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -46,6 +52,8 @@ type AdminService interface { GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // API Key management (admin) @@ -110,6 +118,7 @@ type CreateUserInput struct { Notes string Balance float64 Concurrency int + RPMLimit int AllowedGroups []int64 } @@ -120,6 +129,7 @@ type UpdateUserInput struct { Notes *string Balance *float64 // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0" + RPMLimit *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 @@ -127,6 +137,44 @@ type UpdateUserInput struct { GroupRates map[int64]*float64 } +type AdminBindAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Issuer *string + Metadata map[string]any + 
Channel *AdminBindAuthIdentityChannelInput +} + +type AdminBindAuthIdentityChannelInput struct { + Channel string + ChannelAppID string + ChannelSubject string + Metadata map[string]any +} + +type AdminBoundAuthIdentity struct { + UserID int64 `json:"user_id"` + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"` +} + +type AdminBoundAuthIdentityChannel struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateGroupInput struct { Name string Description string @@ -157,6 +205,8 @@ type CreateGroupInput struct { RequireOAuthOnly bool RequirePrivacySet bool MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制) + RPMLimit int // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -192,6 +242,8 @@ type UpdateGroupInput struct { RequireOAuthOnly *bool RequirePrivacySet *bool MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 + RPMLimit *int // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -275,6 +327,22 @@ type ReplaceUserGroupResult struct { MigratedKeys int64 // 迁移的 Key 数量 } +// UserRPMStatus describes a user's current per-minute RPM usage. 
+type UserRPMStatus struct { + UserRPMUsed int `json:"user_rpm_used"` + UserRPMLimit int `json:"user_rpm_limit"` + PerGroup []UserGroupRPMStatus `json:"per_group"` +} + +// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair. +type UserGroupRPMStatus struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Used int `json:"used"` + Limit int `json:"limit"` + Source string `json:"source"` // "group" | "override" +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -421,6 +489,8 @@ const ( proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" ) +var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available") + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository @@ -430,6 +500,7 @@ type adminServiceImpl struct { apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository userGroupRateRepo UserGroupRateRepository + userRPMCache UserRPMCache billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -454,6 +525,7 @@ func NewAdminService( apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, userGroupRateRepo UserGroupRateRepository, + userRPMCache UserRPMCache, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -472,6 +544,7 @@ func NewAdminService( apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, userGroupRateRepo: userGroupRateRepo, + userRPMCache: userRPMCache, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -491,6 +564,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi if err != nil { 
return nil, 0, err } + if len(users) > 0 { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr) + } else { + for i := range users { + users[i].LastUsedAt = lastUsedByUserID[users[i].ID] + } + } + } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { @@ -535,6 +622,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) if err != nil { return nil, err } + lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr) + } else { + user.LastUsedAt = lastUsedAt + } // 加载用户专属分组倍率 if s.userGroupRateRepo != nil { rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) @@ -555,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu Role: RoleUser, // Always create as regular user, never admin Balance: input.Balance, Concurrency: input.Concurrency, + RPMLimit: input.RPMLimit, Status: StatusActive, AllowedGroups: input.AllowedGroups, } @@ -586,6 +680,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -599,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx 
context.Context, id int64, input *Upda oldConcurrency := user.Concurrency oldStatus := user.Status oldRole := user.Role + oldRPMLimit := user.RPMLimit if input.Email != "" { user.Email = input.Email @@ -624,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.Concurrency = *input.Concurrency } + if input.RPMLimit != nil { + user.RPMLimit = *input.RPMLimit + } + if input.AllowedGroups != nil { user.AllowedGroups = *input.AllowedGroups } @@ -640,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } if s.authCacheInvalidator != nil { - if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联, + // 不失效缓存会让修改在一个 L2 TTL 内失去效果。 + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) } } @@ -762,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag return keys, result.Total, nil } +func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) { + if s.userRPMCache == nil { + return nil, ErrRPMStatusUnavailable + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err) + } + + keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "") + if err != nil { + return nil, err + } + + groupIDSet := make(map[int64]struct{}) + for _, key := range keys { + if key.GroupID != nil && *key.GroupID > 0 { + groupIDSet[*key.GroupID] = struct{}{} + } + } + + groupIDs := make([]int64, 0, len(groupIDSet)) + for groupID := range groupIDSet { + groupIDs = append(groupIDs, groupID) + } + 
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) + + var perGroup []UserGroupRPMStatus + for _, groupID := range groupIDs { + used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID) + if getErr != nil { + logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr) + } + + entry := UserGroupRPMStatus{ + GroupID: groupID, + Used: used, + } + + if s.groupRepo != nil { + if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil { + entry.GroupName = group.Name + entry.Limit = group.RPMLimit + entry.Source = "group" + } else if groupErr != nil { + logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr) + } + } + + if s.userGroupRateRepo != nil { + override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID) + if overrideErr != nil { + logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr) + } else if override != nil { + entry.Limit = *override + entry.Source = "override" + } + } + + perGroup = append(perGroup, entry) + } + + return &UserRPMStatus{ + UserRPMUsed: userRPMUsed, + UserRPMLimit: user.RPMLimit, + PerGroup: perGroup, + }, nil +} + func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { // Return mock data for now return map[string]any{ @@ -788,6 +973,334 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int return codes, result.Total, totalRecharged, nil } +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == 
nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey) + compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey) + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identityRecords, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(compatibleProviderKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). 
+ All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + identity := selectOwnedAdminAuthIdentity(identityRecords, userID) + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(canonicalProviderKey). + SetProviderSubject(providerSubject). + SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID). + SetVerifiedAt(verifiedAt). + SetProviderKey(canonicalProviderKey) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channelRecords, err := tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(compatibleProviderKeys...), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). 
+ All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID) + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(canonicalProviderKey). + SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID). 
+ SetProviderKey(canonicalProviderKey) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} + +func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return []string{providerKey} + } + if providerType != "wechat" { + return []string{providerKey} + } + + keys := []string{providerKey} + if !strings.EqualFold(providerKey, "wechat-main") { + keys = append(keys, "wechat-main") + } + if !strings.EqualFold(providerKey, "wechat") { + keys = append(keys, "wechat") + } + return keys +} + +func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + existingKey = strings.TrimSpace(existingKey) + requestedKey = strings.TrimSpace(requestedKey) + if providerType != "wechat" { + if requestedKey != "" { + return requestedKey + } + return existingKey + } + if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") { + return "wechat-main" + } + if requestedKey != "" { + return requestedKey + } + return existingKey +} + +func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int { + providerType = strings.TrimSpace(strings.ToLower(providerType)) + providerKey = strings.TrimSpace(providerKey) + if 
providerType != "wechat" { + return 0 + } + switch { + case strings.EqualFold(providerKey, "wechat-main"): + return 0 + case strings.EqualFold(providerKey, "wechat"): + return 2 + default: + return 1 + } +} + +func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { + var selected *dbent.AuthIdentity + for _, record := range records { + if record.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool { + for _, record := range records { + if record.UserID != userID { + return true + } + } + return false +} + +func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { + var selected *dbent.AuthIdentityChannel + for _, record := range records { + if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID { + continue + } + if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) { + selected = record + } + } + return selected +} + +func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { + for _, record := range records { + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return true + } + } + return false +} + +func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { + if input == nil { + return nil + } + channel := &AdminBindAuthIdentityChannelInput{ + Channel: strings.TrimSpace(input.Channel), + ChannelAppID: strings.TrimSpace(input.ChannelAppID), + ChannelSubject: 
strings.TrimSpace(input.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(input.Metadata), + } + if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" { + return nil + } + return channel +} + +func normalizeAdminAuthIdentityProviderType(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "email": + return "email" + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + default: + return "" + } +} + +func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity { + if identity == nil { + return nil + } + result := &AdminBoundAuthIdentity{ + UserID: identity.UserID, + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + } + if channel != nil { + result.Channel = &AdminBoundAuthIdentityChannel{ + Channel: strings.TrimSpace(channel.Channel), + ChannelAppID: strings.TrimSpace(channel.ChannelAppID), + ChannelSubject: strings.TrimSpace(channel.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata), + CreatedAt: channel.CreatedAt, + UpdatedAt: channel.UpdatedAt, + } + } + return result +} + +func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { + if input == nil { + return nil + } + if len(input) == 0 { + return map[string]any{} + } + data, err := json.Marshal(input) + if err != nil { + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + out = make(map[string]any, len(input)) + for key, value := range 
input { + out[key] = value + } + } + return out +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} @@ -811,6 +1324,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + platform := input.Platform if platform == "" { platform = PlatformAnthropic @@ -911,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { @@ -1050,6 +1568,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.Platform = input.Platform } if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { @@ -1142,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MessagesDispatchModelConfig != nil { group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) } + if input.RPMLimit != nil { + group.RPMLimit = *input.RPMLimit + } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if 
s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) if len(input.CopyAccountsFromGroupIDs) > 0 { // 去重源分组 IDs @@ -1216,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd } } - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) - } return group, nil } @@ -1286,9 +1811,47 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro if s.userGroupRateRepo == nil { return nil } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } +func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if s.userGroupRateRepo == nil { + return nil + } + if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + +func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + if s.userGroupRateRepo == nil { + return nil + } + for _, e := range entries { + if e.RPMOverride != nil && *e.RPMOverride < 0 { + return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID)) + } + } + if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates 
[]GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 419ddbc329642e9717d545319c6f53afd297347c..fcde5cbf4abd89d7e666c378ad7bee8666c64146 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } @@ -70,6 +79,23 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected") +} + +func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected") +} + +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + panic("unexpected") +} +func 
(s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go new file mode 100644 index 0000000000000000000000000000000000000000..719199f251a2b8710507bd5e284ef3926bb3f582 --- /dev/null +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -0,0 +1,302 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). 
+ SetEmail("bind-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-123", + Metadata: map[string]any{"scene": "migration"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "wechat", result.ProviderType) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.VerifiedAt) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + require.NotNil(t, identity.VerifiedAt) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). 
+ Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "migration", channel.Metadata["scene"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + target, err := client.User.Create(). + SetEmail("target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("subject-1"). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }) + require.Error(t, err) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err)) +} + +func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("same-user@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "first"}, + }) + require.NoError(t, err) + + second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "second"}, + }) + require.NoError(t, err) + require.Equal(t, first.UserID, second.UserID) + require.Equal(t, "second", second.Metadata["source"]) + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("subject-2"), + ). + All(ctx) + require.NoError(t, err) + require.Len(t, identities, 1) + require.Equal(t, "second", identities[0].Metadata["source"]) +} + +func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("wechat-alias@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetProviderSubject("union-legacy-123"). + SetMetadata(map[string]any{"source": "legacy"}). + Save(ctx) + require.NoError(t, err) + + legacyChannel, err := client.AuthIdentityChannel.Create(). + SetIdentityID(legacyIdentity.ID). + SetProviderType("wechat"). + SetProviderKey("wechat"). + SetChannel("open"). 
+ SetChannelAppID("wx-open"). + SetChannelSubject("openid-legacy-123"). + SetMetadata(map[string]any{"scene": "legacy"}). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-legacy-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-legacy-123", + Metadata: map[string]any{"scene": "admin-repair"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", identity.ProviderKey) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + + channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID) + require.NoError(t, err) + require.Equal(t, "wechat-main", channel.ProviderKey) + require.Equal(t, legacyIdentity.ID, channel.IdentityID) + require.Equal(t, "admin-repair", channel.Metadata["scene"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderSubjectEQ("union-legacy-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + channelCount, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-legacy-123"), + ). 
+ Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, channelCount) +} + +func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("invalid-provider@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "github", + ProviderKey: "github-main", + ProviderSubject: "subject-3", + }) + require.Error(t, err) + require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fbc856cf3c1f9541141047b6eea0cec211080edb..fe9e7701a26c1916aa7db35b7d2b732552eac422 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -13,15 +13,18 @@ import ( ) type userRepoStub struct { - user *User - getErr error - createErr error - deleteErr error - exists bool - existsErr error - nextID int64 - created []*User - deletedIDs []int64 + user *User + getErr error + createErr error + deleteErr error + exists bool + existsErr error + nextID int64 + created []*User + updated []*User + deletedIDs []int64 + usersByEmail map[string]*User + getByEmailErr error } func (s *userRepoStub) Create(ctx context.Context, user *User) error { @@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error { user.ID = s.nextID } s.created = append(s.created, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user return nil } @@ -46,7 
+54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { } func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { - panic("unexpected GetByEmail call") + if s.getByEmailErr != nil { + return nil, s.getByEmailErr + } + if s.usersByEmail != nil { + if user, ok := s.usersByEmail[email]; ok { + return user, nil + } + } + if s.user != nil && s.user.Email == email { + return s.user, nil + } + return nil, ErrUserNotFound } func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { @@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { } func (s *userRepoStub) Update(ctx context.Context, user *User) error { - panic("unexpected Update call") + s.updated = append(s.updated, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user + return nil } func (s *userRepoStub) Delete(ctx context.Context, id int64) error { @@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { return s.deleteErr } +func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected List call") } @@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa panic("unexpected ListWithFilters call") } +func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + 
panic("unexpected GetLatestUsedAtByUserIDs call") +} + +func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + panic("unexpected GetLatestUsedAtByUserID call") +} + +func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + panic("unexpected UpdateUserLastActiveAt call") +} + func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error { panic("unexpected UpdateBalance call") } @@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + +func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + panic("unexpected UnbindUserAuthProvider call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2232c9c38b6e9994233e569b6befc976c8a2fd63 --- /dev/null +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -0,0 +1,187 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type ensureEmailCall struct { + userID int64 + email string +} + +type replaceEmailCall struct { + userID int64 + oldEmail string + newEmail string +} + +type emailSyncRepoStub struct { + user *User + nextID int64 + updateCalls int + created []*User + updated []*User + ensureCalls []ensureEmailCall + 
replaceCalls []replaceEmailCall + ensureErr error + replaceErr error +} + +func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error { + if s.nextID != 0 && user.ID == 0 { + user.ID = s.nextID + } + s.created = append(s.created, user) + s.user = user + return nil +} + +func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) { + if s.user == nil { + return nil, ErrUserNotFound + } + cloned := *s.user + return &cloned, nil +} + +func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) { + return nil, ErrUserNotFound +} + +func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) { + return nil, fmt.Errorf("unexpected GetFirstAdmin call") +} + +func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error { + s.updateCalls++ + s.updated = append(s.updated, user) + s.user = user + return nil +} + +func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + return nil, fmt.Errorf("unexpected GetUserAvatar call") +} + +func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + return nil, fmt.Errorf("unexpected UpsertUserAvatar call") +} + +func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error { + return fmt.Errorf("unexpected DeleteUserAvatar call") +} + +func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, fmt.Errorf("unexpected List call") +} + +func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, fmt.Errorf("unexpected ListWithFilters call") +} + +func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return 
map[int64]*time.Time{}, nil +} + +func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + return nil +} + +func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } + +func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } + +func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } + +func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } + +func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } + +func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil } + +func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } + +func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error { + s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email}) + return s.ensureErr +} + +func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error { + s.replaceCalls = append(s.replaceCalls, replaceEmailCall{ + userID: userID, + oldEmail: oldEmail, + newEmail: 
newEmail, + }) + return s.replaceErr +} + +func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + nextID: 55, + ensureErr: fmt.Errorf("unexpected email resync"), + } + svc := &adminServiceImpl{userRepo: repo} + + user, err := svc.CreateUser(context.Background(), &CreateUserInput{ + Email: "admin-created@example.com", + Password: "strong-pass", + }) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(55), user.ID) + require.Empty(t, repo.ensureCalls) + require.Empty(t, repo.replaceCalls) +} + +func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + user: &User{ + ID: 91, + Email: "before@example.com", + Role: RoleUser, + Status: StatusActive, + Concurrency: 3, + }, + replaceErr: fmt.Errorf("unexpected email resync"), + } + svc := &adminServiceImpl{userRepo: repo} + + updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{ + Email: "after@example.com", + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, "after@example.com", updated.Email) + require.Empty(t, repo.replaceCalls) + require.Empty(t, repo.ensureCalls) +} diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go index 77635247d8baeef1ffe8a0dcd49df5a5c5288724..d2efb6441949c13ff59d63a717b14a37e8bc3ed7 100644 --- a/backend/internal/service/admin_service_group_rate_test.go +++ b/backend/internal/service/admin_service_group_rate_test.go @@ -5,8 +5,10 @@ package service import ( "context" "errors" + "net/http" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct { syncedGroupID int64 syncedEntries []GroupRateMultiplierInput syncGroupErr error + + rpmSyncedGroupID int64 + rpmSyncedEntries 
[]GroupRPMOverrideInput + rpmSyncErr error } func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { @@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { if s.getByGroupIDErr != nil { return nil, s.getByGroupIDErr @@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C return s.syncGroupErr } +func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + s.rpmSyncedGroupID = groupID + s.rpmSyncedEntries = entries + return s.rpmSyncErr +} + +func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) return s.deleteByGroupErr @@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { repo := &userGroupRateRepoStubForGroupRate{ getByGroupIDData: map[int64][]UserGroupRateEntry{ 10: { - {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, - {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, + {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)}, + {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)}, }, }, } @@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { require.Len(t, 
entries, 2) require.Equal(t, int64(1), entries[0].UserID) require.Equal(t, "alice", entries[0].UserName) - require.Equal(t, 1.5, entries[0].RateMultiplier) + require.NotNil(t, entries[0].RateMultiplier) + require.Equal(t, 1.5, *entries[0].RateMultiplier) require.Equal(t, int64(2), entries[1].UserID) - require.Equal(t, 0.8, entries[1].RateMultiplier) + require.NotNil(t, entries[1].RateMultiplier) + require.Equal(t, 0.8, *entries[1].RateMultiplier) }) t.Run("returns nil when repo is nil", func(t *testing.T) { @@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { require.Contains(t, err.Error(), "sync failed") }) } + +func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) { + t.Run("syncs entries to repo", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + override := 20 + entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}} + + err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries) + require.NoError(t, err) + require.Equal(t, int64(10), repo.rpmSyncedGroupID) + require.Equal(t, entries, repo.rpmSyncedEntries) + }) + + t.Run("rejects negative override as bad request", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + negative := -1 + + err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{ + {UserID: 2, RPMOverride: &negative}, + }) + require.Error(t, err) + require.Equal(t, http.StatusBadRequest, infraerrors.Code(err)) + require.Zero(t, repo.rpmSyncedGroupID) + }) +} diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index a4c6d0caba6ef7dc4b4594d20adaf4c13b315a1b..eef022406997a4cb79dd2caebabbfc715390c373 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -266,6 +266,31 @@ 
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformAnthropic, + Status: StatusActive, + RPMLimit: 10, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + groupRepo: repo, + authCacheInvalidator: invalidator, + } + + rpmLimit := 60 + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + RPMLimit: &rpmLimit, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.Equal(t, 60, repo.updated.RPMLimit) + require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存") +} + func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { repo := &groupRepoStubForAdmin{} svc := &adminServiceImpl{groupRepo: repo} @@ -621,6 +646,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformOpenAI, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -641,6 +667,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeSubscription, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -695,6 +722,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, 
}) @@ -713,6 +741,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -733,6 +762,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAntigravity, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -750,6 +780,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &zero, }) diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index ceeb52c2944c830d82a071fecae79f4a00719ad2..ff3f65a89d92bbe153f483538f88d52484536820 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" @@ -16,6 +17,8 @@ type userRepoStubForListUsers struct { users []User err error listWithFiltersParams pagination.PaginationParams + lastUsedByUserID map[int64]*time.Time + lastUsedErr error } func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { @@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag }, nil } +func (s 
*userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) { + if s.lastUsedErr != nil { + return nil, s.lastUsedErr + } + result := make(map[int64]*time.Time, len(userIDs)) + for _, userID := range userIDs { + if ts, ok := s.lastUsedByUserID[userID]; ok { + result[userID] = ts + } + } + return result, nil +} + +func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) { + if s.lastUsedErr != nil { + return nil, s.lastUsedErr + } + return s.lastUsedByUserID[userID], nil +} + type userGroupRateRepoStubForListUsers struct { batchCalls int singleCall []int64 @@ -66,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { panic("unexpected SyncUserGroupRates call") } @@ -78,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C panic("unexpected SyncGroupRateMultipliers call") } +func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error { + panic("unexpected SyncGroupRPMOverrides call") +} + +func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { panic("unexpected DeleteByGroupID call") } @@ -130,3 +165,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) { SortOrder: "ASC", }, userRepo.listWithFiltersParams) } + +func 
TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) { + lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second) + userRepo := &userRepoStubForListUsers{ + users: []User{{ID: 101, Email: "u@example.com"}}, + lastUsedByUserID: map[int64]*time.Time{ + 101: &lastUsed, + }, + } + svc := &adminServiceImpl{userRepo: userRepo} + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "") + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, users, 1) + require.NotNil(t, users[0].LastUsedAt) + require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second) +} diff --git a/backend/internal/service/admin_service_rpm_status_test.go b/backend/internal/service/admin_service_rpm_status_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c298f69b6adc880a98db3ad39f55b73e5726ae52 --- /dev/null +++ b/backend/internal/service/admin_service_rpm_status_test.go @@ -0,0 +1,112 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type rpmStatusUserRepoStub struct { + UserRepository + user *User +} + +func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) { + return s.user, nil +} + +type rpmStatusAPIKeyRepoStub struct { + APIKeyRepository + keys []APIKey +} + +func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil +} + +type rpmStatusGroupRepoStub struct { + GroupRepository + groups map[int64]*Group +} + +func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) { + return s.groups[id], nil +} + +type rpmStatusRateRepoStub struct { + UserGroupRateRepository + overrides map[int64]*int +} + +func 
(s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) { + return s.overrides[groupID], nil +} + +type rpmStatusCacheStub struct { + UserRPMCache + userUsed int + groupUsed map[int64]int +} + +func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) { + return s.groupUsed[groupID], nil +} + +func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) { + return s.userUsed, nil +} + +func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) { + groupOneID := int64(1) + groupTwoID := int64(2) + override := 7 + svc := &adminServiceImpl{ + userRepo: &rpmStatusUserRepoStub{user: &User{ + ID: 42, + RPMLimit: 20, + }}, + apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{ + {ID: 100, UserID: 42, GroupID: &groupTwoID}, + {ID: 101, UserID: 42, GroupID: &groupOneID}, + {ID: 102, UserID: 42, GroupID: &groupTwoID}, + {ID: 103, UserID: 42}, + }}, + groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{ + groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10}, + groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60}, + }}, + userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{ + groupTwoID: &override, + }}, + userRPMCache: &rpmStatusCacheStub{ + userUsed: 5, + groupUsed: map[int64]int{ + groupOneID: 3, + groupTwoID: 4, + }, + }, + } + + status, err := svc.GetUserRPMStatus(context.Background(), 42) + require.NoError(t, err) + require.Equal(t, &UserRPMStatus{ + UserRPMUsed: 5, + UserRPMLimit: 20, + PerGroup: []UserGroupRPMStatus{ + {GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"}, + {GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: 
"override"}, + }, + }, status) +} diff --git a/backend/internal/service/admin_service_update_user_rpm_test.go b/backend/internal/service/admin_service_update_user_rpm_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cb4c3986813371b02a68714362aa9731ba0ef4af --- /dev/null +++ b/backend/internal/service/admin_service_update_user_rpm_test.go @@ -0,0 +1,69 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构, +// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。 +type rpmUserRepoStub struct { + *userRepoStub + lastUpdated *User +} + +func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error { + if user == nil { + return nil + } + clone := *user + s.lastUpdated = &clone + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + authCacheInvalidator: invalidator, + } + + newRPM := 60 + updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + RPMLimit: &newRPM, + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 60, updated.RPMLimit) + require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存") +} + +func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + 
authCacheInvalidator: invalidator, + } + + newName := "new" + sameRPM := 10 + _, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + Username: &newName, + RPMLimit: &sameRPM, + }) + require.NoError(t, err) + require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效") +} diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go index 25c66eb43746944899934e403cd4d96b89cdf7f6..02741d37ba8a960edc8b1e36bea893e99595beea 100644 --- a/backend/internal/service/announcement.go +++ b/backend/internal/service/announcement.go @@ -5,6 +5,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -34,8 +35,23 @@ const ( ) var ( - ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound - ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound + ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required") + ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid") + ErrAnnouncementContentRequired = infraerrors.BadRequest( + "ANNOUNCEMENT_CONTENT_REQUIRED", + "announcement content is required", + ) + ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid") + ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest( + "ANNOUNCEMENT_NOTIFY_MODE_INVALID", + "announcement notify_mode is invalid", + ) + ErrAnnouncementInvalidSchedule = infraerrors.BadRequest( + "ANNOUNCEMENT_TIME_RANGE_INVALID", + "starts_at must be before ends_at", + ) ) type AnnouncementTargeting = domain.AnnouncementTargeting diff --git a/backend/internal/service/announcement_service.go 
b/backend/internal/service/announcement_service.go index c0a0681ac9e75eb9451a2e862a742e9c5823f20f..124790419b88ea55b3f4efef9041ec71f78b4938 100644 --- a/backend/internal/service/announcement_service.go +++ b/backend/internal/service/announcement_service.go @@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct { func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("create announcement: nil input") + return nil, ErrAnnouncementNilInput } title := strings.TrimSpace(input.Title) content := strings.TrimSpace(input.Content) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("create announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } if content == "" { - return nil, fmt.Errorf("create announcement: content is required") + return nil, ErrAnnouncementContentRequired } status := strings.TrimSpace(input.Status) @@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem status = AnnouncementStatusDraft } if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("create announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate() @@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem notifyMode = AnnouncementNotifyModeSilent } if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("create announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } if input.StartsAt != nil && input.EndsAt != nil { if !input.StartsAt.Before(*input.EndsAt) { - return nil, fmt.Errorf("create announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } @@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem func (s *AnnouncementService) 
Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("update announcement: nil input") + return nil, ErrAnnouncementNilInput } a, err := s.announcementRepo.GetByID(ctx, id) @@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.Title != nil { title := strings.TrimSpace(*input.Title) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("update announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } a.Title = title } if input.Content != nil { content := strings.TrimSpace(*input.Content) if content == "" { - return nil, fmt.Errorf("update announcement: content is required") + return nil, ErrAnnouncementContentRequired } a.Content = content } if input.Status != nil { status := strings.TrimSpace(*input.Status) if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("update announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } a.Status = status } @@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.NotifyMode != nil { notifyMode := strings.TrimSpace(*input.NotifyMode) if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("update announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } a.NotifyMode = notifyMode } @@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if a.StartsAt != nil && a.EndsAt != nil { if !a.StartsAt.Before(*a.EndsAt) { - return nil, fmt.Errorf("update announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..77fb9896e1728a046291fc337dc898b3d2e6e60a --- /dev/null +++ 
b/backend/internal/service/announcement_service_test.go @@ -0,0 +1,81 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type announcementRepoStub struct { + item *Announcement +} + +func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) { + if s.item == nil { + return nil, ErrAnnouncementNotFound + } + return s.item, nil +} + +func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (*announcementRepoStub) Delete(context.Context, int64) error { + return nil +} + +func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) { + return nil, nil +} + +func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{} + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + + _, err := svc.Create(context.Background(), &CreateAnnouncementInput{ + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + StartsAt: &now, + EndsAt: &now, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +} + +func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{ + item: &Announcement{ + ID: 1, + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + }, + } + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + startsAt := &now + endsAt := &now + + _, err := 
svc.Update(context.Background(), 1, &UpdateAnnouncementInput{ + StartsAt: &startsAt, + EndsAt: &endsAt, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index b1660ea709a24de436973618d3b0053afe9bcd6a..1a1c78b8de1739b98c6f7afed62aac778e89499f 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct { BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` TotalRecharged float64 `json:"total_recharged"` + + // RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。 + RPMLimit int `json:"rpm_limit"` + + // UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。 + // nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。 + UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"` } // APIKeyAuthGroupSnapshot 分组快照 @@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct { AllowMessagesDispatch bool `json:"allow_messages_dispatch"` DefaultMappedModel string `json:"default_mapped_model,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + + // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。 + RPMLimit int `json:"rpm_limit"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 2bd9a09188ac5b247af93e45d997871efb07cfec..974ea66efaddd98737d1903c09b5c76ee4fc847a 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for 
percentage threshold +const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot type apiKeyAuthCacheConfig struct { l1Size int @@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key - snapshot := s.snapshotFromAPIKey(apiKey) + snapshot := s.snapshotFromAPIKey(ctx, apiKey) if snapshot == nil { return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) } @@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn return s.snapshotToAPIKey(key, entry.Snapshot), true, nil } -func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { +func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot { if apiKey == nil || apiKey.User == nil { return nil } @@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, TotalRecharged: apiKey.User.TotalRecharged, + RPMLimit: apiKey.User.RPMLimit, }, } + + // 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。 + if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil { + override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID) + if err == nil && override != nil { + snapshot.User.UserGroupRPMOverride = override + } + // 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询 + } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ ID: apiKey.Group.ID, @@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, + RPMLimit: 
apiKey.Group.RPMLimit, } } return snapshot @@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, TotalRecharged: snapshot.User.TotalRecharged, + RPMLimit: snapshot.User.RPMLimit, + UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride, }, } if snapshot.Group != nil { @@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, + RPMLimit: snapshot.Group.RPMLimit, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 3c2f7dbb5c59edc9a87f7c2c9ab39733c36c6cdf..8cb1b8c42673a428c7d6c3b0e5fceb97cd4de41a 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t }, } - snapshot := svc.snapshotFromAPIKey(apiKey) + snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey) roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) require.NotNil(t, roundTrip) diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go new file mode 100644 index 0000000000000000000000000000000000000000..78f1185d5e08bb31e2aee5c1ce2c6500e811890b --- /dev/null +++ b/backend/internal/service/auth_email_binding.go @@ -0,0 +1,319 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// BindEmailIdentity verifies and binds a local email/password identity to the +// current user, or replaces the existing bound primary email. +func (s *AuthService) BindEmailIdentity( + ctx context.Context, + userID int64, + email string, + verifyCode string, + password string, +) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return nil, err + } + if isReservedEmail(normalizedEmail) { + return nil, ErrEmailReserved + } + if strings.TrimSpace(password) == "" { + return nil, ErrPasswordRequired + } + if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil { + return nil, err + } + + currentUser, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) + if firstRealEmailBind && len(password) < 6 { + return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters") + } + if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) { + return nil, ErrPasswordIncorrect + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return nil, ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return nil, ErrServiceUnavailable + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, fmt.Errorf("hash password: %w", err) + } + + if s.entClient != nil { + if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil { + return nil, err + } + s.revokeEmailIdentitySessions(ctx, userID) + return currentUser, nil + } + + currentUser.Email = normalizedEmail + currentUser.PasswordHash = 
hashedPassword + if err := s.userRepo.Update(ctx, currentUser); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, ErrEmailExists + } + return nil, ErrServiceUnavailable + } + + if firstRealEmailBind { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil { + return nil, fmt.Errorf("apply email first bind defaults: %w", err) + } + } + + s.revokeEmailIdentitySessions(ctx, userID) + return currentUser, nil +} + +// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows. +func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error { + if s == nil { + return ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return err + } + if isReservedEmail(normalizedEmail) { + return ErrEmailReserved + } + if s.emailService == nil { + return ErrServiceUnavailable + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + if errors.Is(err, ErrUserNotFound) { + return ErrUserNotFound + } + return ErrServiceUnavailable + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName) +} + +func normalizeEmailForIdentityBinding(email string) (string, error) { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || len(normalized) > 255 { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(normalized); err != nil { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + return normalized, nil +} + +func 
hasBindableEmailIdentitySubject(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return normalized != "" && !isReservedEmail(normalized) +} + +func (s *AuthService) updateBoundEmailIdentityTx( + ctx context.Context, + currentUser *User, + email string, + hashedPassword string, + applyFirstBindDefaults bool, +) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return ErrServiceUnavailable + } + return nil +} + +func (s *AuthService) updateBoundEmailIdentityWithClient( + ctx context.Context, + client *dbent.Client, + currentUser *User, + email string, + hashedPassword string, + applyFirstBindDefaults bool, +) error { + if client == nil || currentUser == nil || currentUser.ID <= 0 { + return ErrServiceUnavailable + } + + oldEmail := currentUser.Email + if _, err := client.User.UpdateOneID(currentUser.ID). + SetEmail(email). + SetPasswordHash(hashedPassword). 
+ Save(ctx); err != nil { + if dbent.IsConstraintError(err) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil { + if errors.Is(err, ErrEmailExists) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if applyFirstBindDefaults { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil { + return fmt.Errorf("apply email first bind defaults: %w", err) + } + } + + updatedUser, err := client.User.Get(ctx, currentUser.ID) + if err != nil { + return ErrServiceUnavailable + } + currentUser.Email = updatedUser.Email + currentUser.PasswordHash = updatedUser.PasswordHash + currentUser.Balance = updatedUser.Balance + currentUser.Concurrency = updatedUser.Concurrency + currentUser.UpdatedAt = updatedUser.UpdatedAt + return nil +} + +func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) { + if err := s.RevokeAllUserSessions(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err) + } +} + +func replaceBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + oldEmail string, + newEmail string, + source string, +) error { + newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail) + if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil { + return err + } + + oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail) + if oldSubject == "" || oldSubject == newSubject { + return nil + } + + _, err := client.AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(oldSubject), + ). 
+ Exec(ctx) + return err +} + +func ensureBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + subject string, + source string, +) error { + if client == nil || userID <= 0 || subject == "" { + return nil + } + + if strings.TrimSpace(source) == "" { + source = "auth_service_email_bind" + } + + if err := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(subject). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": strings.TrimSpace(source)}). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return err + } + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(subject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity.UserID != userID { + return ErrEmailExists + } + return nil +} + +func normalizeBoundEmailAuthIdentitySubject(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || isReservedEmail(normalized) { + return "" + } + return normalized +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go new file mode 100644 index 0000000000000000000000000000000000000000..a18cf39ccb62b7d9b188573d87f402821887dec8 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -0,0 +1,385 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" +) + +func normalizeOAuthSignupSource(signupSource string) string { + signupSource = 
strings.TrimSpace(strings.ToLower(signupSource)) + switch signupSource { + case "", "email": + return "email" + case "linuxdo", "wechat", "oidc": + return signupSource + default: + return "email" + } +} + +// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth +// account-creation flows without relying on the public registration gate. +func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + return nil, ErrEmailVerifyRequired + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, ErrEmailVerifyRequired + } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + if s == nil || s.emailService == nil { + return nil, ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil { + return nil, err + } + return &SendVerifyCodeResult{ + Countdown: int(verifyCodeCooldown / time.Second), + }, nil +} + +func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil, nil + } + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { + return nil, ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" { + return nil, ErrInvitationCodeRequired + } + + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + return nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, ErrInvitationCodeInvalid + } + return redeemCode, nil +} + +// VerifyOAuthEmailCode verifies the locally entered email verification code for +// 
third-party signup and binding flows. This is intentionally independent from +// the global registration email verification toggle. +func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error { + email = strings.TrimSpace(strings.ToLower(email)) + verifyCode = strings.TrimSpace(verifyCode) + + if email == "" { + return ErrEmailVerifyRequired + } + if verifyCode == "" { + return ErrEmailVerifyRequired + } + if s == nil || s.emailService == nil { + return ErrServiceUnavailable + } + return s.emailService.VerifyCode(ctx, email, verifyCode) +} + +// RegisterOAuthEmailAccount creates a local account from a third-party first +// login after the user has verified a local email address. +func (s *AuthService) RegisterOAuthEmailAccount( + ctx context.Context, + email string, + password string, + verifyCode string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + return nil, nil, err + } + + if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil { + return nil, nil, err + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = normalizeOAuthSignupSource(signupSource) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + + user := 
&User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + Status: StatusActive, + SignupSource: signupSource, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "") + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap +// only after the pending OAuth flow has fully reached its last reversible step. +func (s *AuthService) FinalizeOAuthEmailAccount( + ctx context.Context, + user *User, + invitationCode string, + signupSource string, +) error { + if s == nil || user == nil || user.ID <= 0 { + return ErrServiceUnavailable + } + + signupSource = normalizeOAuthSignupSource(signupSource) + invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + return err + } + if invitationRedeemCode != nil { + if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return ErrInvitationCodeInvalid + } + } + + s.updateOAuthSignupSource(ctx, user.ID, signupSource) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + return nil +} + +// RollbackOAuthEmailAccountCreation removes a partially-created local account +// and restores any invitation code already consumed by that account. 
+func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error { + if s == nil || s.userRepo == nil || userID <= 0 { + return ErrServiceUnavailable + } + if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil { + return err + } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete created oauth user: %w", err) + } + return nil +} + +func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil + } + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { + return ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" || userID <= 0 { + return nil + } + + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + if errors.Is(err, ErrRedeemCodeNotFound) { + return nil + } + return fmt.Errorf("load invitation code: %w", err) + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID { + return nil + } + + redeemCode.Status = StatusUnused + redeemCode.UsedBy = nil + redeemCode.UsedAt = nil + if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil { + return fmt.Errorf("restore invitation code: %w", err) + } + return nil +} + +func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client { + if s == nil || s.entClient == nil { + return nil + } + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return s.entClient +} + +func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if client := s.oauthEmailFlowClient(ctx); client != nil { + entity, err := 
client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrRedeemCodeNotFound + } + return nil, err + } + return &RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: oauthEmailFlowStringValue(entity.Notes), + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil + } + return s.redeemRepo.GetByCode(ctx, invitationCode) +} + +func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error { + if client := s.oauthEmailFlowClient(ctx); client != nil { + affected, err := client.RedeemCode.Update(). + Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)). + SetStatus(StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return ErrRedeemCodeUsed + } + return nil + } + return s.redeemRepo.Use(ctx, invitationID, userID) +} + +func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + if client := s.oauthEmailFlowClient(ctx); client != nil { + update := client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). 
+ SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err + } + return s.redeemRepo.Update(ctx, code) +} + +func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) { + client := s.oauthEmailFlowClient(ctx) + if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" { + return + } + _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx) +} + +func oauthEmailFlowStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} + +// ValidatePasswordCredentials checks the local password without completing the +// login flow. This is used by pending third-party account adoption flows before +// the external identity has been bound. +func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email))) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, ErrInvalidCredentials + } + return nil, ErrServiceUnavailable + } + if !user.IsActive() { + return nil, ErrUserNotActive + } + if !s.CheckPassword(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + return user, nil +} + +// RecordSuccessfulLogin updates last-login activity after a non-standard login +// flow finishes with a real session. 
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + if s != nil && s.userRepo != nil && userID > 0 { + user, err := s.userRepo.GetByID(ctx, userID) + if err == nil && user != nil && !isReservedEmail(user.Email) { + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) + } + } + s.touchUserLogin(ctx, userID) +} diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e3fb2f8528069b45c8dd9cd6468099af084e0ae2 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -0,0 +1,325 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type redeemCodeRepoStub struct { + codesByCode map[string]*RedeemCode + useCalls []struct { + id int64 + userID int64 + } + updateCalls []*RedeemCode +} + +func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error { + panic("unexpected Create call") +} + +func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) { + if s.codesByCode == nil { + return nil, ErrRedeemCodeNotFound + } + redeemCode, ok := s.codesByCode[code] + if !ok { + return nil, ErrRedeemCodeNotFound + } + cloned := *redeemCode + return &cloned, nil +} + +func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + cloned := *code + s.updateCalls = append(s.updateCalls, &cloned) + if s.codesByCode == nil { + s.codesByCode = make(map[string]*RedeemCode) + 
} + s.codesByCode[cloned.Code] = &cloned + return nil +} + +func (s *redeemCodeRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error { + for code, redeemCode := range s.codesByCode { + if redeemCode.ID != id { + continue + } + now := time.Now().UTC() + redeemCode.Status = StatusUsed + redeemCode.UsedBy = &userID + redeemCode.UsedAt = &now + s.codesByCode[code] = redeemCode + s.useCalls = append(s.useCalls, struct { + id int64 + userID int64 + }{id: id, userID: userID}) + return nil + } + return ErrRedeemCodeNotFound +} + +func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +func newOAuthEmailFlowAuthService( + userRepo UserRepository, + redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, + settings map[string]string, + emailCache EmailCache, +) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + 
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache) + + return NewAuthService( + nil, + userRepo, + redeemRepo, + refreshTokenCache, + cfg, + settingService, + emailService, + nil, + nil, + nil, + nil, + ) +} + +func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUnused, + }, + }, + } + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + nil, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "INVITE123", + "oidc", + ) + + require.Nil(t, tokenPair) + require.Nil(t, user) + require.Error(t, err) + require.Contains(t, err.Error(), "generate token pair") + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, userRepo.created, 1) + require.Empty(t, redeemRepo.useCalls) + require.Empty(t, redeemRepo.updateCalls) +} + +func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + 
&refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "", + " OIDC ", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "oidc", userRepo.created[0].SignupSource) +} + +func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) { + userRepo := &userRepoStub{nextID: 43} + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fallback@example.com", + "secret-123", + "246810", + "", + "github", + ) + + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Len(t, userRepo.created, 1) + require.Equal(t, "email", userRepo.created[0].SignupSource) +} + +func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { + userRepo := &userRepoStub{} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUsed, + UsedBy: func() *int64 { + v := int64(42) + return &v + }(), + UsedAt: func() *time.Time { + v := time.Now().UTC() + return &v + }(), + }, + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + &refreshTokenCacheStub{}, + map[string]string{ + 
SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123") + + require.NoError(t, err) + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, redeemRepo.updateCalls, 1) + require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status) + require.Nil(t, redeemRepo.updateCalls[0].UsedBy) + require.Nil(t, redeemRepo.updateCalls[0].UsedAt) +} + +func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { + userRepo := &userRepoStub{deleteErr: errors.New("delete failed")} + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "") + + require.Error(t, err) + require.Contains(t, err.Error(), "delete created oauth user") +} diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go new file mode 100644 index 0000000000000000000000000000000000000000..aa06e59f3079a02ae7d9716bdc7c91029fe4d751 --- /dev/null +++ b/backend/internal/service/auth_oauth_first_bind.go @@ -0,0 +1,104 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + + entsql "entgo.io/ent/dialect/sql" +) + +// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap +// settings the first time a user binds a third-party identity. The grant is +// idempotent per user/provider pair. 
+func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 { + return nil + } + + if dbent.TxFromContext(ctx) != nil { + return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin first bind defaults transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil { + return err + } + return tx.Commit() +} + +func (s *AuthService) applyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true) + if err != nil { + return fmt.Errorf("load auth source defaults: %w", err) + } + if !enabled { + return nil + } + + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() + } + + var result entsql.Result + if err := client.Driver().Exec( + ctx, + `INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + []any{userID, strings.TrimSpace(providerType), "first_bind"}, + &result, + ); err != nil { + return fmt.Errorf("record first bind provider grant: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read first bind provider grant result: %w", err) + } + if affected == 0 { + return nil + } + + if providerDefaults.Balance != 0 { + if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind balance default: %w", err) + } + } + if providerDefaults.Concurrency != 0 { + 
if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind concurrency default: %w", err) + } + } + if s.defaultSubAssigner != nil { + for _, item := range providerDefaults.Subscriptions { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by first bind defaults", + }); err != nil { + return fmt.Errorf("apply first bind subscription default: %w", err) + } + } + } + + return nil +} diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go new file mode 100644 index 0000000000000000000000000000000000000000..6e69c121f3c111b31085dc87d40dc51c298902cb --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service.go @@ -0,0 +1,543 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "hash/fnv" + "sort" + "strings" + "sync" + "time" + + "entgo.io/ent/dialect" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + + entsql "entgo.io/ent/dialect/sql" +) + +var ( + ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") + ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") + ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") + ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") + 
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") + ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") + ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") +) + +const ( + defaultPendingAuthTTL = 15 * time.Minute + defaultPendingAuthCompletionTTL = 5 * time.Minute +) + +type PendingAuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type CreatePendingAuthSessionInput struct { + SessionToken string + Intent string + Identity PendingAuthIdentityKey + TargetUserID *int64 + RedirectTo string + ResolvedEmail string + RegistrationPasswordHash string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + LocalFlowState map[string]any + ExpiresAt time.Time +} + +type IssuePendingAuthCompletionCodeInput struct { + PendingAuthSessionID int64 + BrowserSessionKey string + TTL time.Duration +} + +type IssuePendingAuthCompletionCodeResult struct { + Code string + ExpiresAt time.Time +} + +type PendingIdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type AuthPendingIdentityService struct { + entClient *dbent.Client +} + +var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry() + +type authPendingIdentityScopedKeyLockRegistry struct { + mu sync.Mutex + locks map[string]*authPendingIdentityScopedKeyLockEntry +} + +type authPendingIdentityScopedKeyLockEntry struct { + mu sync.Mutex + refs int +} + +func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry { + return &authPendingIdentityScopedKeyLockRegistry{ + locks: make(map[string]*authPendingIdentityScopedKeyLockEntry), + } +} + +func (r 
*authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() { + normalized := normalizeAuthPendingIdentityLockKeys(keys...) + if len(normalized) == 0 { + return func() {} + } + + entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized)) + r.mu.Lock() + for _, key := range normalized { + entry := r.locks[key] + if entry == nil { + entry = &authPendingIdentityScopedKeyLockEntry{} + r.locks[key] = entry + } + entry.refs++ + entries = append(entries, entry) + } + r.mu.Unlock() + + for _, entry := range entries { + entry.mu.Lock() + } + + return func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].mu.Unlock() + } + + r.mu.Lock() + defer r.mu.Unlock() + for idx, key := range normalized { + entry := entries[idx] + entry.refs-- + if entry.refs == 0 { + delete(r.locks, key) + } + } + } +} + +func normalizeAuthPendingIdentityLockKeys(keys ...string) []string { + if len(keys) == 0 { + return nil + } + + deduped := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + continue + } + deduped[trimmed] = struct{}{} + } + if len(deduped) == 0 { + return nil + } + + normalized := make([]string, 0, len(deduped)) + for key := range deduped { + normalized = append(normalized, key) + } + sort.Strings(normalized) + return normalized +} + +func authPendingIdentityAdvisoryLockHash(key string) int64 { + hasher := fnv.New64a() + _, _ = hasher.Write([]byte(key)) + return int64(hasher.Sum64()) +} + +func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) { + release := authPendingIdentityScopedKeyLocks.lock(keys...) + normalized := normalizeAuthPendingIdentityLockKeys(keys...) 
+ if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres { + return release, nil + } + + for _, key := range normalized { + var rows entsql.Rows + if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil { + release() + return nil, err + } + _ = rows.Close() + } + + return release, nil +} + +func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { + keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)} + if identityID != nil && *identityID > 0 { + keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID)) + } + return keys +} + +func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { + return &AuthPendingIdentityService{entClient: entClient} +} + +func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken := strings.TrimSpace(input.SessionToken) + if sessionToken == "" { + var err error + sessionToken, err = randomOpaqueToken(24) + if err != nil { + return nil, err + } + } + + expiresAt := input.ExpiresAt.UTC() + if expiresAt.IsZero() { + expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) + } + + create := s.entClient.PendingAuthSession.Create(). + SetSessionToken(sessionToken). + SetIntent(strings.TrimSpace(input.Intent)). + SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). + SetRedirectTo(strings.TrimSpace(input.RedirectTo)). + SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). 
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). + SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). + SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). + SetLocalFlowState(copyPendingMap(input.LocalFlowState)). + SetExpiresAt(expiresAt) + if input.TargetUserID != nil { + create = create.SetTargetUserID(*input.TargetUserID) + } + return create.Save(ctx) +} + +func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + + code, err := randomOpaqueToken(24) + if err != nil { + return nil, err + } + ttl := input.TTL + if ttl <= 0 { + ttl = defaultPendingAuthCompletionTTL + } + expiresAt := time.Now().UTC().Add(ttl) + + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeHash(hashPendingAuthCode(code)). + SetCompletionCodeExpiresAt(expiresAt) + if strings.TrimSpace(input.BrowserSessionKey) != "" { + update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) + } + if _, err := update.Save(ctx); err != nil { + return nil, err + } + + return &IssuePendingAuthCompletionCodeResult{ + Code: code, + ExpiresAt: expiresAt, + }, nil +} + +func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) + session, err := s.entClient.PendingAuthSession.Query(). 
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthCodeInvalid + } + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) +} + +func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) +} + +func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken = strings.TrimSpace(sessionToken) + if sessionToken == "" { + return nil, ErrPendingAuthSessionNotFound + } + + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). 
+ Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) consumeSession( + ctx context.Context, + session *dbent.PendingAuthSession, + browserSessionKey string, + expiredErr error, + consumedErr error, +) (*dbent.PendingAuthSession, error) { + if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + + sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState) + now := time.Now().UTC() + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + Where( + pendingauthsession.ConsumedAtIsNil(), + pendingauthsession.ExpiresAtGTE(now), + pendingauthsession.Or( + pendingauthsession.CompletionCodeExpiresAtIsNil(), + pendingauthsession.CompletionCodeExpiresAtGTE(now), + ), + ). + SetConsumedAt(now). + SetLocalFlowState(sanitizedLocalFlowState). + SetCompletionCodeHash(""). 
+ ClearCompletionCodeExpiresAt() + if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { + update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey)) + } + updated, err := update.Save(ctx) + if err == nil { + return updated, nil + } + if !dbent.IsNotFound(err) { + return nil, err + } + + current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID) + if currentErr != nil { + if dbent.IsNotFound(currentErr) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, currentErr + } + if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + return nil, consumedErr +} + +func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any { + sanitized := copyPendingMap(localFlowState) + if len(sanitized) == 0 { + return sanitized + } + + rawCompletion, ok := sanitized["completion_response"] + if !ok { + return sanitized + } + completion, ok := rawCompletion.(map[string]any) + if !ok { + return sanitized + } + + cleanedCompletion := copyPendingMap(completion) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(cleanedCompletion, key) + } + sanitized["completion_response"] = cleanedCompletion + return sanitized +} + +func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { + if session == nil { + return ErrPendingAuthSessionNotFound + } + + now := time.Now().UTC() + if session.ConsumedAt != nil { + return consumedErr + } + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + return expiredErr + } + if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { + return expiredErr + } + if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != 
strings.TrimSpace(session.BrowserSessionKey) { + return ErrPendingAuthBrowserMismatch + } + return nil +} + +func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + tx, err := s.entClient.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return nil, err + } + + client := s.entClient + txCtx := ctx + if err == nil { + defer func() { _ = tx.Rollback() }() + client = tx.Client() + txCtx = dbent.NewTxContext(ctx, tx) + } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil { + client = existingTx.Client() + } + + releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...) + if err != nil { + return nil, err + } + defer releaseLocks() + + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(txCtx); err != nil { + return nil, err + } + } + + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil && *input.IdentityID > 0 { + create = create.SetIdentityID(*input.IdentityID) + } + + decisionID, err := create. + OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). + UpdateNewValues(). 
+ ID(txCtx) + if err != nil { + return nil, err + } + + decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID) + if err != nil { + return nil, err + } + + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, err + } + } + + return decision, nil +} + +func copyPendingMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func randomOpaqueToken(byteLen int) (string, error) { + if byteLen <= 0 { + byteLen = 16 + } + buf := make([]byte, byteLen) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func hashPendingAuthCode(code string) string { + sum := sha256.Sum256([]byte(code)) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..555bb0e7707af41b52ca9ff06561af116bc09ef1 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -0,0 +1,526 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := 
enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return NewAuthPendingIdentityService(client), client +} + +func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("pending-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + TargetUserID: &targetUser.ID, + RedirectTo: "/profile", + ResolvedEmail: "user@example.com", + BrowserSessionKey: "browser-1", + UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"}, + LocalFlowState: map[string]any{"step": "email_required"}, + }) + require.NoError(t, err) + require.NotEmpty(t, session.SessionToken) + require.Equal(t, "bind_current_user", session.Intent) + require.Equal(t, "wechat", session.ProviderType) + require.NotNil(t, session.TargetUserID) + require.Equal(t, targetUser.ID, *session.TargetUserID) + require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"]) + require.Equal(t, "email_required", session.LocalFlowState["step"]) +} + +func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expected", + UpstreamIdentityClaims: 
map[string]any{"nickname": "linux-user"}, + LocalFlowState: map[string]any{"step": "pending"}, + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expected", + }) + require.NoError(t, err) + require.NotEmpty(t, issued.Code) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + require.Empty(t, consumed.CompletionCodeHash) + require.Nil(t, consumed.CompletionCodeExpiresAt) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.ErrorIs(t, err, ErrPendingAuthCodeInvalid) +} + +func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expired", + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expired", + TTL: time.Second, + }) + require.NoError(t, err) + + _, err = client.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)). 
+ Save(ctx) + require.NoError(t, err) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired") + require.ErrorIs(t, err, ErrPendingAuthCodeExpired) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + require.NoError(t, err) + + first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.True(t, first.AdoptDisplayName) + require.False(t, first.AdoptAvatar) + require.Nil(t, first.IdentityID) + + second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + require.NotNil(t, second.IdentityID) + require.Equal(t, identity.ID, *second.IdentityID) + require.True(t, second.AdoptAvatar) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := 
client.User.Create(). + SetEmail("adoption-reassign@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-reassign"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-reassign", + }, + }) + require.NoError(t, err) + + firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: firstSession.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.NotNil(t, firstDecision.IdentityID) + require.Equal(t, identity.ID, *firstDecision.IdentityID) + + secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-reassign", + }, + }) + require.NoError(t, err) + + secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: secondSession.ID, + IdentityID: &identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.NotNil(t, secondDecision.IdentityID) + require.Equal(t, identity.ID, *secondDecision.IdentityID) + + reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedFirst.IdentityID) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) { + svc, client := 
newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption-concurrent@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-concurrent"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-concurrent", + }, + }) + require.NoError(t, err) + + firstCreateStarted := make(chan struct{}) + releaseFirstCreate := make(chan struct{}) + var firstCreate sync.Once + client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator { + return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) { + blocked := false + if m.Op().Is(dbent.OpCreate) { + firstCreate.Do(func() { + blocked = true + close(firstCreateStarted) + }) + } + if blocked { + <-releaseFirstCreate + } + return next.Mutate(ctx, m) + }) + }) + + type adoptionResult struct { + decision *dbent.IdentityAdoptionDecision + err error + } + + input := PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + } + + results := make(chan adoptionResult, 2) + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + <-firstCreateStarted + + go func() { + decision, err := svc.UpsertAdoptionDecision(ctx, input) + results <- adoptionResult{decision: decision, err: err} + }() + + time.Sleep(100 * time.Millisecond) + close(releaseFirstCreate) + + first := <-results + 
second := <-results + + require.NoError(t, first.err) + require.NoError(t, second.err) + require.NotNil(t, first.decision) + require.NotNil(t, second.decision) + require.Equal(t, first.decision.ID, second.decision.ID) + + count, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, count) + + loaded, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, loaded.IdentityID) + require.Equal(t, identity.ID, *loaded.IdentityID) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) { + t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL") + + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("legacy-null-session@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("legacy-null-session"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + _, err = client.ExecContext( + ctx, + `INSERT INTO identity_adoption_decisions + (identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id) + VALUES (?, ?, ?, ?, ?, ?, NULL)`, + identity.ID, + true, + false, + time.Now().UTC(), + time.Now().UTC(), + time.Now().UTC(), + ) + require.NoError(t, err) + legacyDecision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.IdentityIDEQ(identity.ID)). 
+ Only(ctx) + require.NoError(t, err) + require.NotNil(t, legacyDecision.IdentityID) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "legacy-null-session", + }, + }) + require.NoError(t, err) + + decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + + reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID) + require.NoError(t, err) + require.Nil(t, reloadedLegacy.IdentityID) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "subject-session-token", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "token", + }, + }, + }) + require.NoError(t, err) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) { + svc, 
_ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "stale-replay-subject", + }, + BrowserSessionKey: "browser-session", + }) + require.NoError(t, err) + + loaded, err := svc.getBrowserSession(ctx, session.SessionToken) + require.NoError(t, err) + + consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "legacy-token-subject", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }, + }) + require.NoError(t, err) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + stored, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + + completion, ok := stored.LocalFlowState["completion_response"].(map[string]any) + require.True(t, ok) + 
require.NotContains(t, completion, "access_token") + require.NotContains(t, completion, "refresh_token") + require.NotContains(t, completion, "expires_in") + require.NotContains(t, completion, "token_type") + require.Equal(t, "/dashboard", completion["redirect"]) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 847f94a8edccbf73be5abc5a5169284952d81c3f..8c06476d1ad6a8964b1e9402aa9b77b80c314783 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -77,6 +79,12 @@ type DefaultSubscriptionAssigner interface { AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } +type signupGrantPlan struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting +} + // NewAuthService 创建认证服务实例 func NewAuthService( entClient *dbent.Client, @@ -106,6 +114,13 @@ func NewAuthService( } } +func (s *AuthService) EntClient() *dbent.Client { + if s == nil { + return nil + } + return s.entClient +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") @@ -179,12 +194,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, fmt.Errorf("hash password: %w", err) } - // 获取默认配置 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency + grantPlan := s.resolveSignupGrantPlan(ctx, "email") + + 
// 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。 + var defaultRPMLimit int if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) } // 创建用户 @@ -192,8 +207,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw Email: email, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, } @@ -205,7 +221,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, "email", true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -469,12 +486,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return "", nil, fmt.Errorf("hash password: %w", err) } - // 新用户默认值。 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) } newUser := &User{ @@ -482,9 +498,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: 
grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, + SignupSource: signupSource, } if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -501,7 +519,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -520,7 +539,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) @@ -584,11 +602,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) } newUser := &User{ @@ -596,9 +614,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, + SignupSource: signupSource, } if s.entClient != nil && invitationRedeemCode != 
nil { @@ -630,7 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -646,7 +667,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -670,7 +692,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { return nil, nil, fmt.Errorf("generate token pair: %w", err) @@ -678,80 +699,273 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. 
-const pendingOAuthTokenTTL = 10 * time.Minute +func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: notes, + }); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + +func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan { + plan := signupGrantPlan{} + if s != nil && s.cfg != nil { + plan.Balance = s.cfg.Default.UserBalance + plan.Concurrency = s.cfg.Default.UserConcurrency + } + if s == nil || s.settingService == nil { + return plan + } + + plan.Balance = s.settingService.GetDefaultBalance(ctx) + plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) + plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) -// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. 
-const pendingOAuthPurpose = "pending_oauth_registration" + resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) + return plan + } + if !enabled { + return plan + } -type pendingOAuthClaims struct { - Email string `json:"email"` - Username string `json:"username"` - Purpose string `json:"purpose"` - jwt.RegisteredClaims + plan.Balance = resolved.Balance + plan.Concurrency = resolved.Concurrency + plan.Subscriptions = resolved.Subscriptions + return plan } -// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity -// while waiting for the user to supply an invitation code. -func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { - now := time.Now() - claims := &pendingOAuthClaims{ - Email: email, - Username: username, - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, +func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) { + if defaults == nil { + return ProviderDefaultGrantSettings{}, false + } + + switch strings.ToLower(strings.TrimSpace(signupSource)) { + case "email": + return defaults.Email, true + case "linuxdo": + return defaults.LinuxDo, true + case "oidc": + return defaults.OIDC, true + case "wechat": + return defaults.WeChat, true + default: + return ProviderDefaultGrantSettings{}, false } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.Secret)) } -// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. -// Returns ErrInvalidToken when the token is invalid or expired. 
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { - if len(tokenStr) > maxTokenLength { - return "", "", ErrInvalidToken +func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { + if user == nil || user.ID <= 0 { + return } - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return []byte(s.cfg.JWT.Secret), nil - }) - if parseErr != nil { - return "", "", ErrInvalidToken + + if strings.TrimSpace(signupSource) == "" { + signupSource = "email" } - claims, ok := token.Claims.(*pendingOAuthClaims) - if !ok || !token.Valid { - return "", "", ErrInvalidToken + s.updateUserSignupSource(ctx, user.ID, signupSource) + + if touchLogin { + s.touchUserLogin(ctx, user.ID) } - if claims.Purpose != pendingOAuthPurpose { - return "", "", ErrInvalidToken +} + +func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + if strings.TrimSpace(signupSource) == "" { + return + } + if err := s.entClient.User.UpdateOneID(userID). + SetSignupSource(signupSource). 
+ Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err) } - return claims.Email, claims.Username, nil } -func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { - if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { +func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { + if s == nil || s.entClient == nil || userID <= 0 { return } - items := s.settingService.GetDefaultSubscriptions(ctx) - for _, item := range items { - if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ - UserID: userID, - GroupID: item.GroupID, - ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions setting", - }); err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + now := time.Now().UTC() + if err := s.entClient.User.UpdateOneID(userID). + SetLastLoginAt(now). + SetLastActiveAt(now). 
+ Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err) + } +} + +func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) { + if s == nil || user == nil || user.ID <= 0 { + return + } + identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill") + if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err) } } } +func (s *AuthService) shouldApplyEmailFirstBindDefaults( + ctx context.Context, + userID int64, + identity *dbent.AuthIdentity, + created bool, +) bool { + source := emailAuthIdentitySource(identity.Metadata) + if source == "auth_service_login_backfill" { + return false + } + if created { + return true + } + if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID { + return false + } + if source != "auth_service_dual_write" { + return false + } + + hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind") + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err) + return false + } + return !hasGrant +} + +func emailAuthIdentitySource(metadata map[string]any) string { + if len(metadata) == 0 { + return "" + } + raw, ok := metadata["source"] + if !ok { + return "" + } + return strings.TrimSpace(fmt.Sprint(raw)) +} + +func (s *AuthService) hasProviderGrantRecord( + ctx context.Context, + userID int64, + providerType string, + grantReason string, +) (bool, error) { + if s == nil || s.entClient == nil || userID <= 0 { + return false, nil + } + + rows, err := s.entClient.QueryContext( + ctx, + `SELECT 1 FROM user_provider_default_grants WHERE 
user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`, + userID, + strings.TrimSpace(providerType), + strings.TrimSpace(grantReason), + ) + if err != nil { + return false, err + } + defer func() { _ = rows.Close() }() + return rows.Next(), rows.Err() +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return nil, false + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return nil, false + } + if strings.TrimSpace(source) == "" { + source = "auth_service_dual_write" + } + + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() + } + + buildQuery := func() *dbent.AuthIdentityQuery { + return client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(email), + ) + } + + existed, err := buildQuery().Exist(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + + if !existed { + if err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": strings.TrimSpace(source), + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). 
+ Exec(ctx); err != nil { + if isSQLNoRowsError(err) { + return nil, false + } + } + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + } + + identity, err := buildQuery().Only(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return nil, false + } + if identity.UserID != user.ID { + logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID) + return nil, false + } + + return identity, !existed +} + +func inferLegacySignupSource(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + switch { + case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): + return "linuxdo" + case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): + return "oidc" + case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain): + return "wechat" + default: + return "email" + } +} + func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { if s.settingService == nil { return nil @@ -834,7 +1048,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token @@ -853,7 +1068,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { UserID: user.ID, Email: user.Email, Role: user.Role, - TokenVersion: user.TokenVersion, + TokenVersion: 
resolvedTokenVersion(user), RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expiresAt), IssuedAt: jwt.NewNumericDate(now), @@ -919,7 +1134,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( // Security: Check TokenVersion to prevent refreshing revoked tokens // This ensures tokens issued before a password change cannot be refreshed - if claims.TokenVersion != user.TokenVersion { + if claims.TokenVersion != resolvedTokenVersion(user) { return "", ErrTokenRevoked } @@ -1147,7 +1362,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami data := &RefreshTokenData{ UserID: user.ID, - TokenVersion: user.TokenVersion, + TokenVersion: resolvedTokenVersion(user), FamilyID: familyID, CreatedAt: now, ExpiresAt: now.Add(ttl), @@ -1227,7 +1442,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) } // 检查TokenVersion(密码更改后所有Token失效) - if data.TokenVersion != user.TokenVersion { + if data.TokenVersion != resolvedTokenVersion(user) { // TokenVersion不匹配,撤销整个Token家族 _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrTokenRevoked @@ -1272,8 +1487,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) } +// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions. +// Access/refresh token verification both depend on TokenVersion, so bumping it provides +// immediate revocation even if refresh-token cache cleanup later fails. 
+func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + + user.TokenVersion++ + if err := s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user: %w", err) + } + + if err := s.RevokeAllUserSessions(ctx, userID); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err) + } + return nil +} + // hashToken 计算Token的SHA256哈希 func hashToken(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) } + +func resolvedTokenVersion(user *User) int64 { + if user == nil { + return 0 + } + if user.TokenVersionResolved { + return user.TokenVersion + } + + material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash + sum := sha256.Sum256([]byte(material)) + fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff) + return user.TokenVersion ^ fingerprint +} diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cced842a4d0ebd8ded57fd0422bb272fa248796c --- /dev/null +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -0,0 +1,853 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ 
"modernc.org/sqlite" +) + +type emailBindDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + +type flakyEmailBindDefaultSubAssignerStub struct { + err error + calls []*service.AssignSubscriptionInput +} + +func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return nil, false, s.err +} + +func newAuthServiceForEmailBind( + t *testing.T, + settings map[string]string, + emailCache service.EmailCache, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil) +} + +func newAuthServiceForEmailBindWithRefreshCache( + t *testing.T, + settings map[string]string, + emailCache service.EmailCache, + defaultSubAssigner service.DefaultSubscriptionAssigner, + refreshTokenCache service.RefreshTokenCache, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT 
CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-bind-email-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingRepo := &emailBindSettingRepoStub{values: settings} + settingSvc := service.NewSettingService(settingRepo, cfg) + + var emailSvc *service.EmailService + if emailCache != nil { + emailSvc = service.NewEmailService(settingRepo, emailCache) + } + + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + return svc, repo, client +} + +func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain). + SetUsername("legacy-user"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "newemail@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "newemail@example.com", storedUser.Email) + require.Equal(t, 11.0, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("newemail@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, user.ID, assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + sourceUser, err := client.User.Create(). + SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.User.Create(). + SetEmail("taken@example.com"). + SetUsername("taken-user"). 
+ SetPasswordHash("hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailExists) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, sourceUser.ID) + require.NoError(t, err) + require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email) + require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) { + assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain + user, err := client.User.Create(). + SetEmail(originalEmail). + SetUsername("legacy-rollback"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password") + require.ErrorContains(t, err, "apply email first bind defaults") + require.ErrorContains(t, err, "temporary assign failure") + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, originalEmail, storedUser.Email) + require.Equal(t, "old-hash", storedUser.PasswordHash) + require.Equal(t, 2.5, storedUser.Balance) + require.Equal(t, 1, storedUser.Concurrency) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("rollback@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("source-user@example.com"). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailReserved) + require.Nil(t, updatedUser) +} + +func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(7.5). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). 
+ Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "new@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "new@example.com", storedUser.Email) + require.Equal(t, 7.5, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, newIdentityCount) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, oldIdentityCount) + + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). 
+ SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). + Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password") + require.ErrorIs(t, err, service.ErrPasswordIncorrect) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "current@example.com", storedUser.Email) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, oldIdentityCount) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). 
+ Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, newIdentityCount) +} + +func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) { + ctx := context.Background() + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + refreshTokenCache := newEmailBindRefreshTokenCacheStub() + userRepo := newEmailBindUserRepoStub(&service.User{ + ID: 41, + Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain, + Username: "legacy-user", + PasswordHash: "old-hash", + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-bind-email-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + } + emailService := service.NewEmailService(nil, cache) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil) + + oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ + ID: 41, + Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain, + Role: service.RoleUser, + Status: service.StatusActive, + TokenVersion: 4, + }, "") + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + + storedUser, err := userRepo.GetByID(ctx, 41) + require.NoError(t, err) + require.Equal(t, "new@example.com", storedUser.Email) + require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash)) + + _, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken) + require.ErrorIs(t, err, service.ErrTokenRevoked) + + _, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken) + require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid)) +} + +type 
emailBindSettingRepoStub struct { + values map[string]string +} + +func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil +} + +func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *emailBindSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +type emailBindCacheStub struct { + data *service.VerificationCodeData + err error +} + +func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + if s.err != nil { + return nil, s.err + } + return s.data, nil +} + +func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) 
DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +type emailBindRefreshTokenCacheStub struct { + mu sync.Mutex + tokens map[string]*service.RefreshTokenData + userSets map[int64]map[string]struct{} + families map[string]map[string]struct{} +} + +func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub { + return &emailBindRefreshTokenCacheStub{ + tokens: make(map[string]*service.RefreshTokenData), + userSets: make(map[int64]map[string]struct{}), + families: make(map[string]map[string]struct{}), + } +} + +func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + cloned := *data + s.tokens[tokenHash] = &cloned + return nil +} + +func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) { + s.mu.Lock() + defer s.mu.Unlock() + data, ok := s.tokens[tokenHash] + if !ok { + return nil, service.ErrRefreshTokenNotFound + } + cloned := *data + return &cloned, nil +} + 
+func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tokens, tokenHash) + for _, tokenSet := range s.userSets { + delete(tokenSet, tokenHash) + } + for _, tokenSet := range s.families { + delete(tokenSet, tokenHash) + } + return nil +} + +func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for tokenHash := range s.userSets[userID] { + delete(s.tokens, tokenHash) + for _, tokenSet := range s.families { + delete(tokenSet, tokenHash) + } + } + delete(s.userSets, userID) + return nil +} + +func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for tokenHash := range s.families[familyID] { + delete(s.tokens, tokenHash) + for _, tokenSet := range s.userSets { + delete(tokenSet, tokenHash) + } + } + delete(s.families, familyID) + return nil +} + +func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.userSets[userID] == nil { + s.userSets[userID] = make(map[string]struct{}) + } + s.userSets[userID][tokenHash] = struct{}{} + return nil +} + +func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.families[familyID] == nil { + s.families[familyID] = make(map[string]struct{}) + } + s.families[familyID][tokenHash] = struct{}{} + return nil +} + +func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + tokenSet := s.userSets[userID] + out := make([]string, 0, len(tokenSet)) + for tokenHash := range tokenSet { + out = append(out, tokenHash) + } + return out, nil +} + 
+func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + tokenSet := s.families[familyID] + out := make([]string, 0, len(tokenSet)) + for tokenHash := range tokenSet { + out = append(out, tokenHash) + } + return out, nil +} + +func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.families[familyID][tokenHash] + return ok, nil +} + +type emailBindUserRepoStub struct { + mu sync.Mutex + usersByID map[int64]*service.User + usersByEmail map[string]*service.User +} + +func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub { + cloned := cloneEmailBindUser(user) + return &emailBindUserRepoStub{ + usersByID: map[int64]*service.User{ + cloned.ID: cloned, + }, + usersByEmail: map[string]*service.User{ + cloned.Email: cloned, + }, + } +} + +func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil } + +func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) { + s.mu.Lock() + defer s.mu.Unlock() + user, ok := s.usersByID[id] + if !ok { + return nil, service.ErrUserNotFound + } + return cloneEmailBindUser(user), nil +} + +func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) { + s.mu.Lock() + defer s.mu.Unlock() + user, ok := s.usersByEmail[email] + if !ok { + return nil, service.ErrUserNotFound + } + return cloneEmailBindUser(user), nil +} + +func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error { + s.mu.Lock() + defer s.mu.Unlock() + existing, ok := s.usersByID[user.ID] + if !ok { + return service.ErrUserNotFound + } + delete(s.usersByEmail, existing.Email) + cloned 
:= cloneEmailBindUser(user) + s.usersByID[user.ID] = cloned + s.usersByEmail[cloned.Email] = cloned + return nil +} + +func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil } + +func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected DeleteUserAvatar call") +} + +func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error { + return nil +} + +func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } + +func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.usersByEmail[email] + return ok, nil +} + +func (s *emailBindUserRepoStub) 
RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} + +func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + return nil +} + +func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func cloneEmailBindUser(user *service.User) *service.User { + if user == nil { + return nil + } + cloned := *user + return &cloned +} diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2233e427eb7f479d68a11ef424a148b6575069be --- /dev/null +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -0,0 +1,482 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type authIdentityDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s 
*authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + +type flakyAuthIdentityDefaultSubAssignerStub struct { + failuresRemaining int + calls []*service.AssignSubscriptionInput +} + +func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + if s.failuresRemaining > 0 { + s.failuresRemaining-- + return nil, false, errors.New("temporary assign failure") + } + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + +type authIdentitySettingRepoStub struct { + values map[string]string +} + +func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil +} + +func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s 
*authIdentitySettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func newAuthServiceWithEnt( + t *testing.T, + settings map[string]string, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-auth-identity-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ + values: settings, + }, cfg) + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner) + return svc, repo, client +} + +func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + token, user, err := svc.Register(ctx, "user@example.com", "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, user) + + storedUser, err := 
client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "email", storedUser.SignupSource) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("user@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.NotNil(t, identity.VerifiedAt) +} + +func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("login@example.com"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetBalance(1). + SetConcurrency(1). + Save(ctx) + require.NoError(t, err) + + old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) + _, err = client.User.UpdateOneID(user.ID). + SetLastLoginAt(old). + SetLastActiveAt(old). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + require.True(t, storedUser.LastLoginAt.Equal(old)) + require.True(t, storedUser.LastActiveAt.Equal(old)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). 
+ Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) + ctx := context.Background() + + user := &service.User{ + Email: "record@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("record@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("legacy@example.com"). 
+ SetUsername("legacy-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("legacy@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`, + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "5", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + 
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("merged-first-bind@example.com"). + SetUsername("merged-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("bound@example.com"). + SetUsername("bound-user"). + SetPasswordHash(passwordHash). + SetBalance(2). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(user.ID). 
+ SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("bound@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "preexisting"}). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 2.0, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) { + assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("retry-first-bind@example.com"). + SetUsername("retry-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). 
+ Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var count int + rows, err := client.QueryContext( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? 
AND grant_reason = ?`, + userID, + providerType, + grantReason, + ) + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&count)) + require.NoError(t, rows.Err()) + return count +} diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go deleted file mode 100644 index 0472e06c72d7809d55d3f227bbd0bafc778880cd..0000000000000000000000000000000000000000 --- a/backend/internal/service/auth_service_pending_oauth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build unit - -package service - -import ( - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" -) - -func newAuthServiceForPendingOAuthTest() *AuthService { - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret-pending-oauth", - ExpireHour: 1, - }, - } - return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) -} - -// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 -func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - require.NotEmpty(t, token) - - email, username, err := svc.VerifyPendingOAuthToken(token) - require.NoError(t, err) - require.Equal(t, "user@example.com", email) - require.Equal(t, "alice", username) -} - -// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - // 签发一个普通 access token(JWTClaims,无 Purpose 字段) - accessToken, err := svc.GenerateToken(&User{ - ID: 1, - Email: "user@example.com", - Role: RoleUser, - }) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(accessToken) - require.ErrorIs(t, 
err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "some_other_purpose", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "", // 旧 token 无此字段,反序列化后为零值 - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - past := time.Now().Add(-1 * time.Hour) - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: 
jwt.NewNumericDate(past), - IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { - other := NewAuthService(nil, nil, nil, nil, &config.Config{ - JWT: config.JWTConfig{Secret: "other-secret"}, - }, nil, nil, nil, nil, nil, nil) - - token, err := other.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - - svc := newAuthServiceForPendingOAuthTest() - _, _, err = svc.VerifyPendingOAuthToken(token) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - giant := make([]byte, maxTokenLength+1) - for i := range giant { - giant[i] = 'a' - } - _, _, err := svc.VerifyPendingOAuthToken(string(giant)) - require.ErrorIs(t, err, ErrInvalidToken) -} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 103bafe709943160f2d0aefc181a88c1744149d9..dbd18a20a979c4b968669d3bf76b72daab9fe94d 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error { } func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") + if s.err != nil { + return nil, s.err + } + result := make(map[string]string, len(keys)) + for _, key := 
range keys { + if v, ok := s.values[key]; ok { + result[key] = v + } + } + return result, nil } func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { @@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct { err error } +type refreshTokenCacheStub struct{} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) { + return nil, ErrRefreshTokenNotFound +} + +func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) 
(*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -322,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") @@ -469,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) service.defaultSubAssigner = assigner @@ -484,3 +537,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, 7, assigner.calls[1].ValidityDays) } + +func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) { + repo := &userRepoStub{nextID: 52} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`, + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = 
assigner + + _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 12.5, user.Balance) + require.Equal(t, 7, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 53} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "99", + SettingKeyAuthSourceDefaultEmailConcurrency: "88", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-global@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 3.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 54} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "9.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + 
service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-merged@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 9.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { + repo := &userRepoStub{nextID: 61} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`, + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Equal(t, int64(61), user.ID) + require.Equal(t, 21.75, user.Balance) + require.Equal(t, 9, user.Concurrency) + require.Len(t, repo.created, 1) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(22), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) { + existing := &User{ + ID: 88, + Email: "linuxdo-123@linuxdo-connect.invalid", + Username: "existing-linuxdo", + Role: RoleUser, + Status: StatusActive, + Balance: 4, + Concurrency: 1, + TokenVersion: 2, + } + repo 
:= &userRepoStub{user: existing} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.Equal(t, existing.ID, user.ID) + require.Equal(t, 4.0, user.Balance) + require.Equal(t, 1, user.Concurrency) + require.Empty(t, repo.created) + require.Empty(t, assigner.calls) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index f2ad0a3d05f3a06b86548537a97f1bf5c3686fab..4e695eb9bd0b15093f9c5dda7ea5fbec51425d87 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -20,6 +20,9 @@ import ( var ( ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. 
Please retry later.") + // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。 + ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded") + ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -87,6 +90,8 @@ type BillingCacheService struct { userRepo UserRepository subRepo UserSubscriptionRepository apiKeyRateLimitLoader apiKeyRateLimitLoader + userRPMCache UserRPMCache + userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker @@ -104,12 +109,22 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + userRPMCache UserRPMCache, + userGroupRateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, apiKeyRateLimitLoader: apiKeyRepo, + userRPMCache: userRPMCache, + userGroupRateRepo: userGroupRateRepo, cfg: cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) @@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } } + // RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。 + if err := s.checkRPM(ctx, user, group); err != nil { + return err + } + + return nil +} + +// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝: +// +// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。 +// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。 +// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。 +// 3. 
user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。 +// +// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。 +// Redis 故障一律 fail-open(打 warning,不阻塞业务)。 +func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error { + if s == nil || s.userRPMCache == nil || user == nil { + return nil + } + + // ── 第一层:分组级检查(override 或 group.rpm_limit) ── + if group != nil { + // 解析 override:优先从 auth cache snapshot,nil 时回退 DB。 + var override *int + if user.UserGroupRPMOverride != nil { + override = user.UserGroupRPMOverride + } else if s.userGroupRateRepo != nil { + dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm override lookup failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + } else { + override = dbOverride + } + } + + if override != nil { + // override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。 + if *override > 0 { + count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if incErr != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (override) failed for user=%d group=%d: %v", + user.ID, group.ID, incErr, + ) + // fail-open + } else if count > *override { + return ErrGroupRPMExceeded + } + } + // override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。 + } else if group.RPMLimit > 0 { + // 无 override,检查 group.rpm_limit。 + count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (group) failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + // fail-open + } else if count > group.RPMLimit { + return ErrGroupRPMExceeded + } + } + } + + // ── 第二层:用户级全局硬上限(始终生效) ── + if user.RPMLimit > 0 { + count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID) + if err != nil { + logger.LegacyPrintf( + 
"service.billing_cache", + "Warning: rpm increment (user) failed for user=%d: %v", + user.ID, err, + ) + return nil // fail-open + } + if count > user.RPMLimit { + return ErrUserRPMExceeded + } + } + return nil } diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go new file mode 100644 index 0000000000000000000000000000000000000000..de66136fcd035fe5d976e77bd3d0e6b9619fbb51 --- /dev/null +++ b/backend/internal/service/billing_cache_service_rpm_test.go @@ -0,0 +1,253 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。 +type userRPMCacheStub struct { + userGroupCalls int32 + userCalls int32 + + userGroupCounts []int // 依次返回的计数值 + userGroupErr error + userCounts []int + userErr error +} + +func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1 + if s.userGroupErr != nil { + return 0, s.userGroupErr + } + if idx < len(s.userGroupCounts) { + return s.userGroupCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1 + if s.userErr != nil { + return 0, s.userErr + } + if idx < len(s.userCounts) { + return s.userCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + return 0, nil +} + +func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) { + return 0, nil +} + +// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。 +type rpmOverrideRepoStub struct { + UserGroupRateRepository + + override *int + err error + calls int32 +} + +func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ 
context.Context, _, _ int64) (*int, error) { + atomic.AddInt32(&s.calls, 1) + if s.err != nil { + return nil, s.err + } + return s.override, nil +} + +func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService { + t.Helper() + // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。 + // 我们只直接测 checkRPM。 + svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}) + t.Cleanup(svc.Stop) + return svc +} + +func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) { + override := 2 + // user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) + + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数") + // 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用") + require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls)) +} + +func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) { + override := 100 // override 很高 + // user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, 
svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) { + zero := 0 + // user 计数: 依次返回 1..6 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 5} + group := &Group{ID: 10, RPMLimit: 100} + + // override=0 跳过分组计数,但 user.RPMLimit=5 仍生效 + for i := 0; i < 5; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1) + } + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, + "override=0 跳过分组但 user 全局上限仍应生效") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器") + require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) { + zero := 0 + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} // user 也不限 + group := &Group{ID: 10, RPMLimit: 100} + + for i := 0; i < 50; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数") +} + +func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) { + // user-group 计数: 5, 6;user 计数: 默认 1(不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, 
repo) + + user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超 + group := &Group{ID: 10, RPMLimit: 5} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超 + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5 + + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls)) + // 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回") +} + +func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) { + cache := &userRPMCacheStub{userGroupCounts: []int{3}} + repo := &rpmOverrideRepoStub{err: errors.New("db down")} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 10} + + // override 查询失败后应继续尝试 group 分支(不直接拒绝) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls)) +} + +func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + group := &Group{ID: 10, RPMLimit: 0} // 分组未设限 + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键") + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := 
&rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 0} + + for i := 0; i < 10; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) { + cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 5} + + // Redis 故障时应 fail-open,不拒绝请求 + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) +} + +func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + + // 无 group(纯用户级限流场景),不应查询 rpm_override。 + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override") + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10})) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) + 
require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls)) +} diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 4a8b8f03e570c67319b946f82a6ba3f2700c52e6..962becf0612e6f827e8750dc4240b55883b6d2a7 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -86,13 +86,21 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, return &User{ID: id, Balance: s.balance}, nil } +func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + +func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { + return nil +} + func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { cache := &billingCacheMissStub{} userRepo := &balanceLoadUserRepoStub{ delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 7d7045e2920073a65cf9d37fda0799cb12b9c768..849e24b814570b7e634c7fe16cf6cdbc80e442bc 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -92,7 
+92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 32a54cbedfb63b055a32754e63e10d0b0b8baaaa..392b3e0ba9fbc2e59bf2f639322265ce3df4c851 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() { SupportsCacheBreakdown: false, } - // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) - s.fallbackPrices["gpt-5.1"] = &ModelPricing{ - InputPricePerToken: 1.25e-6, // $1.25 per MTok - InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok - OutputPricePerToken: 10e-6, // $10 per MTok - OutputPricePerTokenPriority: 20e-6, // $20 per MTok - CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok - CacheReadPricePerToken: 0.125e-6, - CacheReadPricePerTokenPriority: 0.25e-6, - SupportsCacheBreakdown: false, - } // OpenAI GPT-5.4(业务指定价格) s.fallbackPrices["gpt-5.4"] = &ModelPricing{ InputPricePerToken: 2.5e-6, // $2.5 per MTok @@ -228,18 +217,15 @@ func (s *BillingService) initFallbackPricing() { LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, } + // GPT-5.5 暂无独立定价,回退到 GPT-5.4 + s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"] + s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ InputPricePerToken: 7.5e-7, OutputPricePerToken: 4.5e-6, CacheReadPricePerToken: 7.5e-8, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ - InputPricePerToken: 2e-7, - OutputPricePerToken: 1.25e-6, - CacheReadPricePerToken: 2e-8, - 
SupportsCacheBreakdown: false, - } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -251,8 +237,8 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.35e-6, SupportsCacheBreakdown: false, } - // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 - s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + // Codex 族兜底统一按 GPT-5.3 Codex 价格计费 + s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{ InputPricePerToken: 1.5e-6, // $1.5 per MTok InputPricePerTokenPriority: 3e-6, // $3 per MTok OutputPricePerToken: 12e-6, // $12 per MTok @@ -262,17 +248,6 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.3e-6, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ - InputPricePerToken: 1.75e-6, - InputPricePerTokenPriority: 3.5e-6, - OutputPricePerToken: 14e-6, - OutputPricePerTokenPriority: 28e-6, - CacheCreationPricePerToken: 1.75e-6, - CacheReadPricePerToken: 0.175e-6, - CacheReadPricePerTokenPriority: 0.35e-6, - SupportsCacheBreakdown: false, - } - s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -316,22 +291,16 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { normalized := normalizeCodexModel(modelLower) switch normalized { + case "gpt-5.5": + return s.fallbackPrices["gpt-5.5"] case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] - case "gpt-5.4-nano": - return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": return s.fallbackPrices["gpt-5.2"] - case "gpt-5.2-codex": - return s.fallbackPrices["gpt-5.2-codex"] - case "gpt-5.3-codex": + case "gpt-5.3-codex", "gpt-5.3-codex-spark": return s.fallbackPrices["gpt-5.3-codex"] - case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": - return 
s.fallbackPrices["gpt-5.1-codex"] - case "gpt-5.1": - return s.fallbackPrices["gpt-5.1"] } } @@ -448,8 +417,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, }) } - if input.RateMultiplier <= 0 { - input.RateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。 + if input.RateMultiplier < 0 { + input.RateMultiplier = 0 } var breakdown *CostBreakdown @@ -493,8 +463,9 @@ func (s *BillingService) computeTokenBreakdown( rateMultiplier float64, serviceTier string, applyLongCtx bool, ) *CostBreakdown { - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。 + if rateMultiplier < 0 { + rateMultiplier = 0 } inputPrice := pricing.InputPricePerToken @@ -665,8 +636,14 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens } func isOpenAIGPT54Model(model string) bool { - normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) - return normalized == "gpt-5.4" + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel + // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + normalized := normalizeCodexModel(trimmed) + return normalized == "gpt-5.4" || normalized == "gpt-5.5" } // CalculateCostWithConfig 使用配置中的默认倍率计算费用 @@ -831,9 +808,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag // 计算总费用 totalCost := unitPrice * float64(imageCount) - // 应用倍率 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣) + if rateMultiplier < 0 { + rateMultiplier = 0 } actualCost := totalCost * rateMultiplier diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index fa90f6bba551ff692c07c84a029c7b3635abd398..8d3ca9877871fe4172e3da764ac7f3fef2eaf06d 100644 --- 
a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { require.Equal(t, 0.0, cost.ActualCost) } -// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 +// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费 +// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。 func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) require.InDelta(t, 0.201, cost.TotalCost, 0.0001) - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go new file mode 100644 index 0000000000000000000000000000000000000000..83788196110e59cef2e4ba0dc54cb350ce30019a --- /dev/null +++ b/backend/internal/service/billing_service_rate_multiplier_test.go @@ -0,0 +1,63 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被 +// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。 +func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 // ActualCost / TotalCost + }{ + {"negative clamped to 0", -1.5, 0}, + {"zero passes through as 0 (defense in depth)", 0, 0}, + {"positive 2x applied", 2.0, 2.0}, + {"positive 0.5x applied", 0.5, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier) + 
require.NoError(t, err) + require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero") + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} + +// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径 +// 同样遵循"负数 → 0"语义。 +func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + price := 0.04 + cfg := &ImagePriceConfig{Price1K: &price} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 + }{ + {"negative clamped to 0", -0.5, 0}, + {"zero passes through", 0, 0}, + {"positive 3x applied", 3.0, 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier) + require.NotNil(t, cost) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 2cf134e2baa4c0983c01851e2949da7c08b9e22c..222abd6990dc324d0e07afada2d534b064a48131 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) } -func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) -} - -func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costNeg, err := 
svc.CalculateCost("claude-sonnet-4", tokens, -1.0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) -} - func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { svc := newTestBillingService() @@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { require.Contains(t, err.Error(), "pricing not found") } -func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.1") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) -} - func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { svc := newTestBillingService() @@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { require.Zero(t, pricing.LongContextInputThreshold) } -func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.4-nano") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12) - require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12) - require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12) - require.Zero(t, pricing.LongContextInputThreshold) -} - func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { svc := newTestBillingService() @@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, - {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, {name: 
"openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, - {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, - {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, - {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6}, + {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6}, + {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6}, + {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6}, {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, {name: "non supported family", model: "qwen-max", expectNilPricing: true}, } diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go index 694c338418d188b7684bed18706f6ac684fa1872..e6a92d1a8c68723eda5cc8fdd8388c619b4231cd 100644 --- a/backend/internal/service/billing_service_unified_test.go +++ b/backend/internal/service/billing_service_unified_test.go @@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { require.Equal(t, string(BillingModeImage), cost.BillingMode) } -func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为: +// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。 +func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} - costZero, err := 
bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 0, // should default to 1.0 - Resolver: resolver, - }) - require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, - RateMultiplier: 1.0, + RateMultiplier: 0, Resolver: resolver, }) require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } -func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为: +// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。 +func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} - costNeg, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, @@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 93beb97277a89adcba30e1b5bfd33fe3c8d461cc..158bf8a31bff53d9229bd25c855882057915ae9a 100644 --- 
a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -111,6 +111,18 @@ func (c *Channel) IsActive() bool { return c.Status == StatusActive } +// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。 +// 作为 *Channel 的实体方法集中管理默认值,service 层只需在 Channel 进入内存 +// (缓存装填、repo 读出)时调用一次,下游读路径就无需重复兜底。 +func (c *Channel) normalizeBillingModelSource() { + if c == nil { + return + } + if c.BillingModelSource == "" { + c.BillingModelSource = BillingModelSourceChannelMapped + } +} + // GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 // 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。 func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { @@ -345,3 +357,209 @@ type ChannelUsageFields struct { BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" ModelMappingChain string // 映射链描述,如 "a→b→c" } + +// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户) +type SupportedModel struct { + Name string // 用户侧模型名 + Platform string // 所属平台 + Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价) +} + +// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。 +const wildcardSuffix = "*" + +// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。 +// +// "claude-opus-*" → ("claude-opus-", true) +// "claude-opus-4" → ("claude-opus-4", false) +// "*" → ("", true) +// +// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。 +func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) { + if strings.HasSuffix(pattern, wildcardSuffix) { + return strings.TrimSuffix(pattern, wildcardSuffix), true + } + return pattern, false +} + +// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。 +// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。 +func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing { + if c == nil { + return nil + } + modelLower := strings.ToLower(model) + for i := range c.ModelPricing { + if c.ModelPricing[i].Platform != platform { + continue + } + for _, m := range c.ModelPricing[i].Models { + if strings.ToLower(m) == modelLower { + 
cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + return nil +} + +// platformPricingIndex 是单个平台下定价信息的复合索引。 +// 一次扫描即可同时支持精确查找(exact 分支)与有序遍历(wildcard 分支), +// 避免 SupportedModels 对每个平台重复扫描定价列表。 +// +// byLower 与 names/originalCase 共享同一套去重规则:以 lower-case 模型名为 key, +// 首个命中保留其原始大小写。names 维持按定价行扫描顺序的稳定迭代。 +type platformPricingIndex struct { + byLower map[string]*ChannelModelPricing // lowercased model name → pricing (Clone'd) + originalCase map[string]string // lowercased model name → original-case model name + names []string // priced model names in their ORIGINAL case, insertion-ordered, deduped case-insensitively (first wins) +} + +// buildPricingIndex 对渠道的定价列表做一次扫描,按 platform 聚合为查找索引。 +// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。 +// 通配符后缀条目(如 "claude-*")不被索引(它们是模式,不是具体模型名)。 +// 同一平台中以大小写不敏感方式去重,先出现者保留原始大小写。 +func buildPricingIndex(pricings []ChannelModelPricing) map[string]*platformPricingIndex { + idx := make(map[string]*platformPricingIndex) + for i := range pricings { + p := pricings[i] + pidx, ok := idx[p.Platform] + if !ok { + pidx = &platformPricingIndex{ + byLower: make(map[string]*ChannelModelPricing), + originalCase: make(map[string]string), + names: make([]string, 0), + } + idx[p.Platform] = pidx + } + for _, m := range p.Models { + if _, wild := splitWildcardSuffix(m); wild { + continue + } + lower := strings.ToLower(m) + if _, exists := pidx.byLower[lower]; exists { + continue // 首个命中胜出(case-insensitive 去重后第一个定价 / 第一个原始大小写) + } + cp := pricings[i].Clone() + pidx.byLower[lower] = &cp + pidx.originalCase[lower] = m + pidx.names = append(pidx.names, m) + } + } + return idx +} + +// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。 +// +// 算法(mapping ∪ pricing 并联): +// +// - Pass A(mapping):遍历 ModelMapping +// - 精确 src → target:显示名 = src(用户视角),定价用 target 在同 platform 定价里查 +// (mapping 改写后实际计费的是 target;这是用户感知的"实际花费")。 +// target 为空或为通配符时退化为按 src 自查。 +// - 通配符 src(如 "claude-3-*"):用同 platform 定价里前缀匹配的模型作为候选展开, +// 每个候选用自身定价(通配符场景一般是 passthrough,target 通常也是通配符)。 
+// - "*" 单独 mapping key 走通配符分支(前缀为空 → 全展开)。 +// - Pass B(pricing-only):遍历 ModelPricing 中所有非通配符模型,对未在 Pass A 添加过的 +// 补齐——显示名 = 定价模型名,定价 = 自身(这是关键修复:定价存在即代表渠道支持该模型, +// 即使没配映射)。 +// +// 显示名命中定价时使用**定价的原始大小写**(定价是模型身份的事实来源)。 +// 按 (Platform, Name) 稳定排序,按 (Platform, lowercase(Name)) 去重,先到者胜出。 +// +// 注意:定价仅在 channel.ModelPricing 内查找——全局 LiteLLM 回落由调用方 +// (`ChannelService.ListAvailable`)在合成展示数据时叠加。 +func (c *Channel) SupportedModels() []SupportedModel { + if c == nil { + return nil + } + if len(c.ModelMapping) == 0 && len(c.ModelPricing) == 0 { + return nil + } + + idx := buildPricingIndex(c.ModelPricing) + + type dedupKey struct { + platform string + name string + } + seen := make(map[dedupKey]struct{}) + result := make([]SupportedModel, 0) + + // lookup 在 platform pricing index 中按精确名查定价,命中时返回定价大小写。 + lookup := func(pidx *platformPricingIndex, name string) (display string, pricing *ChannelModelPricing) { + if pidx == nil || name == "" { + return name, nil + } + lower := strings.ToLower(name) + if p, ok := pidx.byLower[lower]; ok { + return pidx.originalCase[lower], p + } + return name, nil + } + + add := func(platform, displayName string, pricing *ChannelModelPricing) { + key := dedupKey{platform: platform, name: strings.ToLower(displayName)} + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + result = append(result, SupportedModel{ + Name: displayName, + Platform: platform, + Pricing: pricing, + }) + } + + // Pass A:从 mapping 展开 + for platform, mapping := range c.ModelMapping { + if len(mapping) == 0 { + continue + } + pidx := idx[platform] + for src, target := range mapping { + prefix, isWild := splitWildcardSuffix(src) + if isWild { + if pidx == nil { + continue + } + prefixLower := strings.ToLower(prefix) + for _, candidate := range pidx.names { + if strings.HasPrefix(strings.ToLower(candidate), prefixLower) { + display, pricing := lookup(pidx, candidate) + add(platform, display, pricing) + } + } + continue + } + // 精确 mapping:定价按 target 
查;target 缺失/通配则退化按 src 查 + pricingKey := target + if pricingKey == "" { + pricingKey = src + } + if _, targetWild := splitWildcardSuffix(pricingKey); targetWild { + pricingKey = src + } + _, pricing := lookup(pidx, pricingKey) + // 显示名优先用 src 在定价里的原始大小写(若 src 本身是个定价模型名) + displayName, _ := lookup(pidx, src) + add(platform, displayName, pricing) + } + } + + // Pass B:从 pricing 补齐 mapping 未覆盖的具体模型(修复"定价存在但没配映射 → 不显示") + for platform, pidx := range idx { + for _, name := range pidx.names { + display, pricing := lookup(pidx, name) + add(platform, display, pricing) + } + } + + sort.SliceStable(result, func(i, j int) bool { + if result[i].Platform != result[j].Platform { + return result[i].Platform < result[j].Platform + } + return result[i].Name < result[j].Name + }) + return result +} diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go new file mode 100644 index 0000000000000000000000000000000000000000..815730e320723085f73ac5ddcc059f9de87c99af --- /dev/null +++ b/backend/internal/service/channel_available.go @@ -0,0 +1,149 @@ +package service + +import ( + "context" + "fmt" + "sort" + "strings" +) + +// AvailableGroupRef 渠道视图中关联分组的简要信息。 +// +// 用户侧「可用渠道」页面据此展示:专属分组 vs 公开分组(IsExclusive)、 +// 订阅 vs 标准(SubscriptionType)、默认倍率(RateMultiplier)。用户专属倍率 +// 不在这里暴露,前端自己通过 /groups/rates 拉取,和 API 密钥页面保持一致。 +type AvailableGroupRef struct { + ID int64 + Name string + Platform string + SubscriptionType string + RateMultiplier float64 + IsExclusive bool +} + +// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 + +// 关联的分组 + 推导出的支持模型列表(无通配符)。 +type AvailableChannel struct { + ID int64 + Name string + Description string + Status string + BillingModelSource string + RestrictModels bool + Groups []AvailableGroupRef + SupportedModels []SupportedModel +} + +// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。 +// +// 支持模型通过 (*Channel).SupportedModels() 计算(mapping ∪ pricing 并联)。 +// 对于渠道未配置定价的模型,进一步用 PricingService 的全局 LiteLLM 数据合成 +// 
一份展示用定价,让用户看到默认价格而非"未配置"。 +// +// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中 +// 的分组(已停用或删除)会被忽略。 +// +// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast, +// 避免静默掩盖注入缺失。 +func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) { + channels, err := s.repo.ListAll(ctx) + if err != nil { + return nil, fmt.Errorf("list channels: %w", err) + } + + groups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, fmt.Errorf("list active groups: %w", err) + } + groupByID := make(map[int64]AvailableGroupRef, len(groups)) + for i := range groups { + g := groups[i] + groupByID[g.ID] = AvailableGroupRef{ + ID: g.ID, + Name: g.Name, + Platform: g.Platform, + SubscriptionType: g.SubscriptionType, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + } + } + + out := make([]AvailableChannel, 0, len(channels)) + for i := range channels { + ch := &channels[i] + groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs)) + for _, gid := range ch.GroupIDs { + if ref, ok := groupByID[gid]; ok { + groups = append(groups, ref) + } + } + sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name }) + + ch.normalizeBillingModelSource() + + supported := ch.SupportedModels() + s.fillGlobalPricingFallback(supported) + + out = append(out, AvailableChannel{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + BillingModelSource: ch.BillingModelSource, + RestrictModels: ch.RestrictModels, + Groups: groups, + SupportedModels: supported, + }) + } + + sort.SliceStable(out, func(i, j int) bool { + return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) + }) + return out, nil +} + +// fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份 +// 展示用定价(按 token 计费)。仅用于「可用渠道」展示,不影响真实计费链路。 +// +// 当 s.pricingService 为 nil(测试场景),跳过回落。 +func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) { + if s.pricingService == nil { + 
return + } + for i := range models { + if models[i].Pricing != nil { + continue + } + lp := s.pricingService.GetModelPricing(models[i].Name) + if lp == nil { + continue + } + models[i].Pricing = synthesizePricingFromLiteLLM(lp) + } +} + +// synthesizePricingFromLiteLLM 把 LiteLLM 的定价数据转成 ChannelModelPricing 形态, +// 仅用于展示。BillingMode 固定为 token;图片场景的 OutputCostPerImageToken 也归到 +// ImageOutputPrice 字段(与渠道侧"图片输出按 token 计价"语义一致)。 +// +// LiteLLM 中字段 0 视为未配置,不带入展示。 +func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing) *ChannelModelPricing { + if lp == nil { + return nil + } + return &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: nonZeroPtr(lp.InputCostPerToken), + OutputPrice: nonZeroPtr(lp.OutputCostPerToken), + CacheWritePrice: nonZeroPtr(lp.CacheCreationInputTokenCost), + CacheReadPrice: nonZeroPtr(lp.CacheReadInputTokenCost), + ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken), + } +} + +func nonZeroPtr(v float64) *float64 { + if v == 0 { + return nil + } + return &v +} diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8be70ceb62e56e4abce643467e6e01b381e21588 --- /dev/null +++ b/backend/internal/service/channel_available_test.go @@ -0,0 +1,177 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub, +// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。 +// listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。 +// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。 +type stubGroupRepoForAvailable struct { + activeGroups []Group + listActiveErr error + listActiveCalls int +} + +func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) { + s.listActiveCalls++ + if s.listActiveErr != nil { + 
return nil, s.listActiveErr + } + return s.activeGroups, nil +} + +func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil } +func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil } +func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil } +func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, nil +} +func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil +} +func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} +func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} +func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx 
context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + +// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels, +// groupRepo 由参数决定。传入空 stub 表示「活跃分组列表为空」。 +func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService { + repo := &mockChannelRepository{ + listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil }, + } + return NewChannelService(repo, groupRepo, nil, nil) +} + +func TestListAvailable_EmptyActiveGroups_NoGroupsAttached(t *testing.T) { + // 活跃分组列表为空时,渠道的 Groups 应为空切片,不报错。 + channels := []Channel{{ + ID: 1, + Name: "chA", + Status: StatusActive, + GroupIDs: []int64{10, 20}, + }} + svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{}) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + require.Empty(t, out[0].Groups) +} + +func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) { + // 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。 + channels := []Channel{{ + ID: 1, + Name: "chA", + Status: StatusActive, + GroupIDs: []int64{1, 99}, + }} + groupRepo := &stubGroupRepoForAvailable{ + activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}}, + } + svc := newAvailableChannelService(channels, groupRepo) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + require.Len(t, out[0].Groups, 1) + require.Equal(t, int64(1), out[0].Groups[0].ID) +} + +func TestListAvailable_SortedByName(t *testing.T) { + channels := []Channel{ + {ID: 1, Name: "beta"}, + {ID: 2, Name: "Alpha"}, + {ID: 3, Name: "charlie"}, + } + svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{}) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 3) + require.Equal(t, "Alpha", out[0].Name) + require.Equal(t, "beta", out[1].Name) + require.Equal(t, "charlie", out[2].Name) 
+} + +func TestListAvailable_ListAllErrorPropagates(t *testing.T) { + // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。 + sentinel := errors.New("list-all-boom") + repo := &mockChannelRepository{ + listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel }, + } + groupRepo := &stubGroupRepoForAvailable{} + svc := NewChannelService(repo, groupRepo, nil, nil) + out, err := svc.ListAvailable(context.Background()) + require.Nil(t, out) + require.ErrorIs(t, err, sentinel) + require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v") + require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive") +} + +func TestListAvailable_ListActiveErrorPropagates(t *testing.T) { + // groupRepo.ListActive 返回错误时 ListAvailable 应直接返回包装后的错误。 + sentinel := errors.New("list-active-boom") + svc := newAvailableChannelService( + []Channel{{ID: 1, Name: "chA"}}, + &stubGroupRepoForAvailable{listActiveErr: sentinel}, + ) + out, err := svc.ListAvailable(context.Background()) + require.Nil(t, out) + require.ErrorIs(t, err, sentinel) + require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v") +} + +func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) { + // 渠道 BillingModelSource 为空时应回填为 BillingModelSourceChannelMapped, + // 显式值应原样保留(由 service 层统一处理,避免各 handler 重复默认逻辑)。 + channels := []Channel{ + {ID: 1, Name: "empty", BillingModelSource: ""}, + {ID: 2, Name: "explicit", BillingModelSource: BillingModelSourceUpstream}, + } + svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{}) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 2) + + // 按 Name 查找,避免依赖排序副作用。 + byName := make(map[string]string, len(out)) + for _, ch := range out { + byName[ch.Name] = ch.BillingModelSource + } + require.Equal(t, BillingModelSourceChannelMapped, byName["empty"]) + require.Equal(t, BillingModelSourceUpstream, byName["explicit"]) +} 
diff --git a/backend/internal/service/channel_monitor_aggregator.go b/backend/internal/service/channel_monitor_aggregator.go new file mode 100644 index 0000000000000000000000000000000000000000..09020f5fa42770a14ce1b41881f6fba517b9b422 --- /dev/null +++ b/backend/internal/service/channel_monitor_aggregator.go @@ -0,0 +1,292 @@ +package service + +import ( + "context" + "fmt" + "log/slog" +) + +// 渠道监控聚合层:把 latest + availability 拼成 admin/user 视图所需的 summary / detail。 +// 所有方法都遵守"失败仅日志,返回零值"的原则,避免 N+1 查询失败拖垮列表渲染。 + +// BatchMonitorStatusSummary 批量聚合多个监控的 latest + 7d 可用率(admin/user list 用,消除 N+1)。 +// 失败时返回空 map,错误仅日志,不影响列表渲染。 +// +// 参数: +// - ids: 要聚合的 monitor ID 列表 +// - primaryByID: monitor ID -> primary model(用于读 7d 可用率与 latest 状态) +// - extrasByID: monitor ID -> extra models 列表(用于读 latest 状态填充 ExtraModels) +func (s *ChannelMonitorService) BatchMonitorStatusSummary( + ctx context.Context, + ids []int64, + primaryByID map[int64]string, + extrasByID map[int64][]string, +) map[int64]MonitorStatusSummary { + out := make(map[int64]MonitorStatusSummary, len(ids)) + if len(ids) == 0 { + return out + } + latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids) + if err != nil { + slog.Warn("channel_monitor: batch load latest failed", "error", err) + latestMap = map[int64][]*ChannelMonitorLatest{} + } + availMap, err := s.repo.ComputeAvailabilityForMonitors(ctx, ids, monitorAvailability7Days) + if err != nil { + slog.Warn("channel_monitor: batch compute availability failed", "error", err) + availMap = map[int64][]*ChannelMonitorAvailability{} + } + + for _, id := range ids { + out[id] = buildStatusSummary( + indexLatestByModel(latestMap[id]), + indexAvailabilityByModel(availMap[id]), + primaryByID[id], + extrasByID[id], + ) + } + return out +} + +// ListUserView 用户只读视图:列出所有 enabled 监控的概览。 +// 使用批量聚合接口避免 N+1: +// +// 1 次查 monitors; +// 1 次批量 latest(含 ping_latency_ms); +// 1 次批量 7d availability; +// 1 次批量 timeline(主模型最近 N 条)。 +func (s *ChannelMonitorService) 
ListUserView(ctx context.Context) ([]*UserMonitorView, error) { + monitors, err := s.repo.ListEnabled(ctx) + if err != nil { + return nil, fmt.Errorf("list enabled monitors: %w", err) + } + if len(monitors) == 0 { + return []*UserMonitorView{}, nil + } + + ids, primaryByID, extrasByID := collectMonitorIndexes(monitors) + summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID) + latestMap := s.batchLatest(ctx, ids) + timelineMap := s.batchTimeline(ctx, ids, primaryByID) + + views := make([]*UserMonitorView, 0, len(monitors)) + for _, m := range monitors { + primaryLatest := pickLatest(latestMap[m.ID], m.PrimaryModel) + views = append(views, buildUserViewFromSummary(m, summaries[m.ID], primaryLatest, timelineMap[m.ID])) + } + return views, nil +} + +// collectMonitorIndexes 把 monitors 列表按 ID 展开为聚合查询所需的三个索引结构。 +func collectMonitorIndexes(monitors []*ChannelMonitor) ([]int64, map[int64]string, map[int64][]string) { + ids := make([]int64, 0, len(monitors)) + primaryByID := make(map[int64]string, len(monitors)) + extrasByID := make(map[int64][]string, len(monitors)) + for _, m := range monitors { + ids = append(ids, m.ID) + primaryByID[m.ID] = m.PrimaryModel + extrasByID[m.ID] = m.ExtraModels + } + return ids, primaryByID, extrasByID +} + +// batchLatest 批量取 latest per model,失败仅日志(与现有 BatchMonitorStatusSummary 一致,不阻断列表渲染)。 +func (s *ChannelMonitorService) batchLatest(ctx context.Context, ids []int64) map[int64][]*ChannelMonitorLatest { + latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids) + if err != nil { + slog.Warn("channel_monitor: user view batch latest failed", "error", err) + return map[int64][]*ChannelMonitorLatest{} + } + return latestMap +} + +// batchTimeline 批量取每个 monitor 主模型最近 monitorTimelineMaxPoints 条历史。 +func (s *ChannelMonitorService) batchTimeline( + ctx context.Context, + ids []int64, + primaryByID map[int64]string, +) map[int64][]*ChannelMonitorHistoryEntry { + timelineMap, err := s.repo.ListRecentHistoryForMonitors(ctx, 
ids, primaryByID, monitorTimelineMaxPoints) + if err != nil { + slog.Warn("channel_monitor: user view batch timeline failed", "error", err) + return map[int64][]*ChannelMonitorHistoryEntry{} + } + return timelineMap +} + +// pickLatest 从 latest 切片中挑出指定 model 对应项,未命中返回 nil。 +func pickLatest(rows []*ChannelMonitorLatest, model string) *ChannelMonitorLatest { + if model == "" { + return nil + } + for _, r := range rows { + if r.Model == model { + return r + } + } + return nil +} + +// GetUserDetail 用户只读视图:单个监控详情(每个模型 7d/15d/30d 可用率与平均延迟)。 +// 不暴露 api_key。 +func (s *ChannelMonitorService) GetUserDetail(ctx context.Context, id int64) (*UserMonitorDetail, error) { + m, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + if !m.Enabled { + return nil, ErrChannelMonitorNotFound + } + + latest, err := s.repo.ListLatestPerModel(ctx, id) + if err != nil { + return nil, fmt.Errorf("list latest per model: %w", err) + } + availMap, err := s.collectAvailabilityWindows(ctx, id) + if err != nil { + return nil, err + } + + models := mergeModelDetails(m, latest, availMap) + return &UserMonitorDetail{ + ID: m.ID, + Name: m.Name, + Provider: m.Provider, + GroupName: m.GroupName, + Models: models, + }, nil +} + +// collectAvailabilityWindows 一次性查询 7/15/30 天三个窗口,按模型组织。 +func (s *ChannelMonitorService) collectAvailabilityWindows(ctx context.Context, monitorID int64) (map[int]map[string]*ChannelMonitorAvailability, error) { + out := make(map[int]map[string]*ChannelMonitorAvailability, 3) + windows := []int{monitorAvailability7Days, monitorAvailability15Days, monitorAvailability30Days} + for _, w := range windows { + rows, err := s.repo.ComputeAvailability(ctx, monitorID, w) + if err != nil { + return nil, fmt.Errorf("compute availability %dd: %w", w, err) + } + out[w] = indexAvailabilityByModel(rows) + } + return out, nil +} + +// ---------- 纯函数 helper(无 IO,可在 batch / 单 monitor / detail 路径复用)---------- + +// indexLatestByModel 把 latest 切片按 model 索引(小工具,避免在 hot path 
重复写)。 +func indexLatestByModel(rows []*ChannelMonitorLatest) map[string]*ChannelMonitorLatest { + m := make(map[string]*ChannelMonitorLatest, len(rows)) + for _, r := range rows { + m[r.Model] = r + } + return m +} + +// indexAvailabilityByModel 把 availability 切片按 model 索引。 +func indexAvailabilityByModel(rows []*ChannelMonitorAvailability) map[string]*ChannelMonitorAvailability { + m := make(map[string]*ChannelMonitorAvailability, len(rows)) + for _, r := range rows { + m[r.Model] = r + } + return m +} + +// buildStatusSummary 由 latest + availability 字典构造 MonitorStatusSummary。 +// 不做任何 IO,纯组装,便于在 batch 与单 monitor 路径复用。 +func buildStatusSummary( + latestByModel map[string]*ChannelMonitorLatest, + availByModel map[string]*ChannelMonitorAvailability, + primary string, + extras []string, +) MonitorStatusSummary { + summary := MonitorStatusSummary{ExtraModels: make([]ExtraModelStatus, 0, len(extras))} + if primary != "" { + if l, ok := latestByModel[primary]; ok { + summary.PrimaryStatus = l.Status + summary.PrimaryLatencyMs = l.LatencyMs + } + if a, ok := availByModel[primary]; ok { + summary.Availability7d = a.AvailabilityPct + } + } + for _, model := range extras { + entry := ExtraModelStatus{Model: model} + if l, ok := latestByModel[model]; ok { + entry.Status = l.Status + entry.LatencyMs = l.LatencyMs + } + summary.ExtraModels = append(summary.ExtraModels, entry) + } + return summary +} + +// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary + 主模型 latest + timeline 装填 UserMonitorView(无 IO)。 +// primaryLatest 可能为 nil(该监控尚无历史);timelineEntries 可能为空。 +func buildUserViewFromSummary( + m *ChannelMonitor, + summary MonitorStatusSummary, + primaryLatest *ChannelMonitorLatest, + timelineEntries []*ChannelMonitorHistoryEntry, +) *UserMonitorView { + view := &UserMonitorView{ + ID: m.ID, + Name: m.Name, + Provider: m.Provider, + GroupName: m.GroupName, + PrimaryModel: m.PrimaryModel, + PrimaryStatus: summary.PrimaryStatus, + PrimaryLatencyMs: summary.PrimaryLatencyMs, + 
Availability7d: summary.Availability7d, + ExtraModels: summary.ExtraModels, + Timeline: buildTimelinePoints(timelineEntries), + } + if primaryLatest != nil { + view.PrimaryPingLatencyMs = primaryLatest.PingLatencyMs + } + return view +} + +// buildTimelinePoints 把 history entry 裁剪为 timeline 点(去除 message/ID/Model,减小响应体)。 +func buildTimelinePoints(entries []*ChannelMonitorHistoryEntry) []UserMonitorTimelinePoint { + out := make([]UserMonitorTimelinePoint, 0, len(entries)) + for _, e := range entries { + out = append(out, UserMonitorTimelinePoint{ + Status: e.Status, + LatencyMs: e.LatencyMs, + PingLatencyMs: e.PingLatencyMs, + CheckedAt: e.CheckedAt, + }) + } + return out +} + +// mergeModelDetails 合并 latest + availability 三个窗口为 ModelDetail 列表。 +// 复用 indexLatestByModel,避免在多处重复写 build map 逻辑。 +func mergeModelDetails( + m *ChannelMonitor, + latest []*ChannelMonitorLatest, + availMap map[int]map[string]*ChannelMonitorAvailability, +) []ModelDetail { + all := append([]string{m.PrimaryModel}, m.ExtraModels...) 
+ latestByModel := indexLatestByModel(latest) + out := make([]ModelDetail, 0, len(all)) + for _, model := range all { + d := ModelDetail{Model: model} + if l, ok := latestByModel[model]; ok { + d.LatestStatus = l.Status + d.LatestLatencyMs = l.LatencyMs + } + if a, ok := availMap[monitorAvailability7Days][model]; ok { + d.Availability7d = a.AvailabilityPct + d.AvgLatency7dMs = a.AvgLatencyMs + } + if a, ok := availMap[monitorAvailability15Days][model]; ok { + d.Availability15d = a.AvailabilityPct + } + if a, ok := availMap[monitorAvailability30Days][model]; ok { + d.Availability30d = a.AvailabilityPct + } + out = append(out, d) + } + return out +} diff --git a/backend/internal/service/channel_monitor_challenge.go b/backend/internal/service/channel_monitor_challenge.go new file mode 100644 index 0000000000000000000000000000000000000000..e81a9e2a30f2b2fc3063005e7812d83c090ae767 --- /dev/null +++ b/backend/internal/service/channel_monitor_challenge.go @@ -0,0 +1,80 @@ +package service + +import ( + "fmt" + "math/rand/v2" + "regexp" + "strconv" +) + +// monitorChallengePromptTemplate 1:1 复刻 BingZi-233/check-cx 的 few-shot 模板。 +const monitorChallengePromptTemplate = `Calculate and respond with ONLY the number, nothing else. + +Q: 3 + 5 = ? +A: 8 + +Q: 12 - 7 = ? +A: 5 + +Q: %d %s %d = ? 
+A:` + +// monitorChallengeNumberRegex 提取响应中的所有整数(含负号)。 +var monitorChallengeNumberRegex = regexp.MustCompile(`-?\d+`) + +// monitorChallenge 一次 challenge 的 prompt + 期望答案。 +type monitorChallenge struct { + Prompt string + Expected string +} + +// generateChallenge 生成一次随机算术 challenge: +// - 随机两个 [monitorChallengeMin, monitorChallengeMax] 整数 +// - 50% 加 / 50% 减;减法用 max - min 保证非负 +// - 渲染 few-shot 模板 +// +// 不强求加密随机:math/rand/v2 足够分散,避免 crypto/rand 的开销。 +func generateChallenge() monitorChallenge { + a := randIntInRange(monitorChallengeMin, monitorChallengeMax) + b := randIntInRange(monitorChallengeMin, monitorChallengeMax) + + if rand.IntN(2) == 0 { //nolint:gosec // 仅用于生成测试问题,无安全影响 + // 加法 + return monitorChallenge{ + Prompt: fmt.Sprintf(monitorChallengePromptTemplate, a, "+", b), + Expected: strconv.Itoa(a + b), + } + } + + // 减法,保证非负 + hi, lo := a, b + if lo > hi { + hi, lo = lo, hi + } + return monitorChallenge{ + Prompt: fmt.Sprintf(monitorChallengePromptTemplate, hi, "-", lo), + Expected: strconv.Itoa(hi - lo), + } +} + +// randIntInRange 返回 [min, max] 闭区间的随机整数。 +func randIntInRange(minVal, maxVal int) int { + if maxVal <= minVal { + return minVal + } + return minVal + rand.IntN(maxVal-minVal+1) //nolint:gosec +} + +// validateChallenge 在响应文本中查找 expected 整数答案,返回是否通过校验。 +func validateChallenge(responseText, expected string) bool { + if responseText == "" || expected == "" { + return false + } + matches := monitorChallengeNumberRegex.FindAllString(responseText, -1) + for _, m := range matches { + if m == expected { + return true + } + } + return false +} diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go new file mode 100644 index 0000000000000000000000000000000000000000..33570629d78b78b455faa772fe354a0aa7376520 --- /dev/null +++ b/backend/internal/service/channel_monitor_checker.go @@ -0,0 +1,443 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + 
"net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/tidwall/gjson" +) + +// monitorHTTPClient 共享一个 http.Client,避免每次检测重建 transport。 +// 自定义 Transport 在 dial 时强制再次校验 IP,防止 DNS rebinding 绕过 validateEndpoint。 +var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout) + +// monitorPingHTTPClient 用于 endpoint origin 的 HEAD ping,超时更短。 +var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout) + +// newSSRFSafeHTTPClient 返回一个使用 safeDialContext 的 http.Client。 +// 仅供监控模块对外发起请求使用——所有目标都应是公网 endpoint。 +func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client { + tr := &http.Transport{ + DialContext: safeDialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 16, + IdleConnTimeout: monitorIdleConnTimeout, + TLSHandshakeTimeout: monitorTLSHandshakeTimeout, + ResponseHeaderTimeout: monitorResponseHeaderTimeout, + } + return &http.Client{Timeout: timeout, Transport: tr} +} + +// CheckOptions 承载一次检测的自定义入参。 +// 所有字段都是可选(零值即等价于"用默认行为")。 +type CheckOptions struct { + // ExtraHeaders 用户自定义 HTTP 头(merge 到 adapter 默认 headers,用户优先)。 + ExtraHeaders map[string]string + // BodyOverrideMode: off | merge | replace + BodyOverrideMode string + // BodyOverride 在 merge 模式下做浅合并(key 命中黑名单时静默丢弃), + // 在 replace 模式下直接当作完整 body。 + BodyOverride map[string]any +} + +// runCheckForModel 对单个 (provider, model) 做一次完整检测。 +// 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。 +// +// opts 承载模板 / 监控快照带来的自定义配置。nil 等同于 "off + 无 extra headers"。 +func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string, opts *CheckOptions) *CheckResult { + res := &CheckResult{ + Model: model, + Status: MonitorStatusError, + CheckedAt: time.Now(), + } + + challenge := generateChallenge() + mode := bodyOverrideMode(opts) + + start := time.Now() + respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt, opts) + latency := time.Since(start) + latencyMs := int(latency / time.Millisecond) + res.LatencyMs = 
&latencyMs + + if err != nil { + res.Status = MonitorStatusError + res.Message = truncateMessage(sanitizeErrorMessage(err.Error())) + return res + } + if statusCode < 200 || statusCode >= 300 { + // 错误路径:用 rawBody 而非 respText(gjson textPath 抽取在错误响应里通常为空, + // 会丢掉真正的上游错误信息,例如 `{"error":{"message":"No available accounts ..."}}`)。 + res.Status = MonitorStatusError + bodySnippet := truncateForErrorBody(rawBody) + res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, bodySnippet))) + return res + } + + // Replace 模式:跳过 challenge 校验(用户 body 是静态的,challenge 没法嵌入)。 + // 改用「HTTP 2xx + 响应文本(adapter.textPath 抽取)非空」作为 operational 判定。 + // 响应文本为空则降级为 failed(视为上游回了 200 但没实际内容)。 + if mode == MonitorBodyOverrideModeReplace { + if strings.TrimSpace(respText) == "" { + res.Status = MonitorStatusFailed + res.Message = truncateMessage("replace-mode: upstream returned 2xx with empty text") + return res + } + return finalizeOperationalOrDegraded(res, latency, latencyMs) + } + + if !validateChallenge(respText, challenge.Expected) { + res.Status = MonitorStatusFailed + res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText))) + return res + } + + return finalizeOperationalOrDegraded(res, latency, latencyMs) +} + +// finalizeOperationalOrDegraded 负责走到最后一步的 operational/degraded 判定。 +// 拆出来是为了让 runCheckForModel 不超过 30 行。 +func finalizeOperationalOrDegraded(res *CheckResult, latency time.Duration, latencyMs int) *CheckResult { + if latency >= monitorDegradedThreshold { + res.Status = MonitorStatusDegraded + res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs)) + return res + } + res.Status = MonitorStatusOperational + return res +} + +// bodyOverrideMode 归一取 opts.BodyOverrideMode,nil opts / 空串都视为 off。 +func bodyOverrideMode(opts *CheckOptions) string { + if opts == nil || opts.BodyOverrideMode == "" { + return MonitorBodyOverrideModeOff + } + 
return opts.BodyOverrideMode +} + +// pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。 +// 失败时返回 nil(不影响主状态判定)。 +func pingEndpointOrigin(ctx context.Context, endpoint string) *int { + origin, err := extractOrigin(endpoint) + if err != nil || origin == "" { + return nil + } + req, err := http.NewRequestWithContext(ctx, http.MethodHead, origin, nil) + if err != nil { + return nil + } + start := time.Now() + resp, err := monitorPingHTTPClient.Do(req) + if err != nil { + return nil + } + defer func() { _ = resp.Body.Close() }() + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, monitorPingDiscardMaxBytes)) + ms := int(time.Since(start) / time.Millisecond) + return &ms +} + +// providerAdapter 描述某个 provider 在 challenge 检测中需要的 4 件事: +// - 拼出请求路径(含 model 占位) +// - 序列化请求体 +// - 构造鉴权头 +// - 从响应 JSON 中按 path 提取文本(gjson path) +// +// 加新 provider 只需要在 providerAdapters 里增加一个条目,无需触碰 callProvider / validateProvider。 +type providerAdapter struct { + buildPath func(model string) string + buildBody func(model, prompt string) ([]byte, error) + buildHeaders func(apiKey string) map[string]string + textPath string // gjson 提取响应文本的 path +} + +// providerAdapters 全部已支持的 provider。键值即 MonitorProvider* 字符串。 +// +//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。 +var providerAdapters = map[string]providerAdapter{ + MonitorProviderOpenAI: { + buildPath: func(string) string { return providerOpenAIPath }, + buildBody: func(model, prompt string) ([]byte, error) { + return json.Marshal(map[string]any{ + "model": model, + "messages": []map[string]string{{"role": "user", "content": prompt}}, + "max_tokens": monitorChallengeMaxTokens, + "stream": false, + }) + }, + buildHeaders: func(apiKey string) map[string]string { + return map[string]string{"Authorization": "Bearer " + apiKey} + }, + textPath: "choices.0.message.content", + }, + MonitorProviderAnthropic: { + buildPath: func(string) string { return providerAnthropicPath }, + buildBody: func(model, prompt string) ([]byte, 
error) { + return json.Marshal(map[string]any{ + "model": model, + "messages": []map[string]string{{"role": "user", "content": prompt}}, + "max_tokens": monitorChallengeMaxTokens, + }) + }, + buildHeaders: func(apiKey string) map[string]string { + return map[string]string{ + "x-api-key": apiKey, + "anthropic-version": monitorAnthropicAPIVersion, + } + }, + textPath: "content.0.text", + }, + MonitorProviderGemini: { + // Gemini 把 model 名写在 URL path 上:/v1beta/models/{model}:generateContent + buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) }, + buildBody: func(_, prompt string) ([]byte, error) { + return json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"parts": []map[string]any{{"text": prompt}}}, + }, + "generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens}, + }) + }, + // 使用 x-goog-api-key header 而不是 ?key= query,避免 *url.Error 把 key 回填到错误日志。 + buildHeaders: func(apiKey string) map[string]string { + return map[string]string{"x-goog-api-key": apiKey} + }, + textPath: "candidates.0.content.parts.0.text", + }, +} + +// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。 +// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。 +func isSupportedProvider(p string) bool { + _, ok := providerAdapters[p] + return ok +} + +// callProvider 通过 providerAdapters 分发到具体实现。 +// opts 承载用户的自定义 headers / body 覆盖(可为 nil)。 +// +// 返回值: +// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串 +// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包 +// - status: HTTP 状态码 +// - err: 网络 / 序列化错误 +func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) { + adapter, ok := providerAdapters[provider] + if !ok { + return "", "", 0, fmt.Errorf("unsupported provider %q", provider) + } + body, err := buildRequestBody(adapter, provider, model, prompt, opts) + if err != nil { + return 
"", "", 0, err + } + headers := mergeHeaders(adapter.buildHeaders(apiKey), opts) + full := joinURL(endpoint, adapter.buildPath(model)) + respBytes, status, err := postRawJSON(ctx, full, body, headers) + if err != nil { + return "", "", status, err + } + return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil +} + +// mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。 +// 用户值覆盖默认;命中黑名单(hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。 +func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string { + if opts == nil || len(opts.ExtraHeaders) == 0 { + return base + } + out := make(map[string]string, len(base)+len(opts.ExtraHeaders)) + for k, v := range base { + out[k] = v + } + for k, v := range opts.ExtraHeaders { + if IsForbiddenHeaderName(k) { + continue + } + out[k] = v + } + return out +} + +// buildRequestBody 根据 body_override_mode 构造请求 body。 +// +// - off: adapter 默认 body +// - merge: adapter 默认 body 与 BodyOverride 浅合并;BodyOverride 中命中 +// bodyMergeKeyDenyList[provider] 的 key 会被静默丢弃,避免破坏 challenge / model 路由 +// - replace: 直接 marshal BodyOverride 作为完整 body +// +// 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。 +func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) { + mode := bodyOverrideMode(opts) + + if mode == MonitorBodyOverrideModeReplace { + if opts == nil || len(opts.BodyOverride) == 0 { + return nil, fmt.Errorf("replace mode: body_override is empty") + } + body, err := json.Marshal(opts.BodyOverride) + if err != nil { + return nil, fmt.Errorf("marshal body_override (replace): %w", err) + } + return body, nil + } + + defaultBody, err := adapter.buildBody(model, prompt) + if err != nil { + return nil, fmt.Errorf("marshal default body: %w", err) + } + if mode != MonitorBodyOverrideModeMerge || opts == nil || len(opts.BodyOverride) == 0 { + return defaultBody, nil + } + + var defaultMap map[string]any + if err := json.Unmarshal(defaultBody, 
&defaultMap); err != nil { + return nil, fmt.Errorf("unmarshal default body for merge: %w", err) + } + deny := bodyMergeKeyDenyList[provider] + for k, v := range opts.BodyOverride { + if deny[k] { + continue + } + defaultMap[k] = v + } + merged, err := json.Marshal(defaultMap) + if err != nil { + return nil, fmt.Errorf("marshal merged body: %w", err) + } + return merged, nil +} + +// bodyMergeKeyDenyList 在 merge 模式下,禁止用户覆盖这些 provider-specific 的关键字段。 +// 思路抄 check-cx 的 EXCLUDED_METADATA_KEYS:保护 challenge / model 路由不被用户误伤。 +// 用户想动这些字段就用 replace 模式(已知会跳 challenge 校验)。 +// +//nolint:gochecknoglobals // 静态查表,初始化后不变。 +var bodyMergeKeyDenyList = map[string]map[string]bool{ + MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true}, + MonitorProviderAnthropic: {"model": true, "messages": true}, + MonitorProviderGemini: {"contents": true}, +} + +// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。 +// adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。 +func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, 0, fmt.Errorf("build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } + + resp, err := monitorHTTPClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("do request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, monitorResponseMaxBytes)) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("read body: %w", err) + } + return respBody, resp.StatusCode, nil +} + +// joinURL 把 base origin 与 path 拼成完整 URL。 +// 容忍 base 末尾有/无斜杠,path 必带前导斜杠。 +func joinURL(base, path string) string { + base = strings.TrimRight(base, "/") + if 
!strings.HasPrefix(path, "/") { + path = "/" + path + } + return base + path +} + +// extractOrigin 从一个 endpoint URL 中提取 scheme://host[:port] 部分。 +func extractOrigin(endpoint string) (string, error) { + u, err := url.Parse(endpoint) + if err != nil { + return "", err + } + if u.Scheme == "" || u.Host == "" { + return "", errors.New("endpoint missing scheme or host") + } + return u.Scheme + "://" + u.Host, nil +} + +// monitorSensitiveQueryParamRegex 匹配 URL query 中可能泄露凭证的参数: +// key / api_key / api-key / access_token / token / authorization / x-api-key。 +// 大小写不敏感,匹配 `?name=value` 或 `&name=value` 形式(value 截到 & 或字符串末尾)。 +var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`) + +// monitorAPIKeyPatterns 匹配常见 provider 的 API key 字面量。 +// 顺序敏感:sk-ant- 必须放在 sk- 之前,否则会被通用 sk- 模式先消费。 +var monitorAPIKeyPatterns = []struct { + pattern *regexp.Regexp + replace string +}{ + // Anthropic(带前缀,必须先匹配):sk-ant-xxxxxxx + {regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"}, + // OpenAI / Anthropic 通用 sk-: sk-xxxxxxx + {regexp.MustCompile(`sk-[A-Za-z0-9-]{20,}`), "sk-***REDACTED***"}, + // Gemini / Google API Key:固定前缀 + 35 位 + {regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"}, + // JWT 三段式(Bearer 后常出现):eyJxxx.eyJxxx.signature + {regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"}, +} + +// sanitizeErrorMessage 擦除错误/响应文本中可能泄露的 API key。 +// 处理两类来源: +// 1. URL query 中的 ?key= / ?api_key= 等(Go *url.Error 会回填完整 URL) +// 2. 
上游 HTTP body 文本里直接出现的 sk-* / AIza* / JWT 等密钥碎片 +// +// 注意:与 gemini_messages_compat_service.go 的 sanitizeUpstreamErrorMessage 关注点类似但参数集更广, +// 监控模块独立维护,避免互相耦合。 +func sanitizeErrorMessage(msg string) string { + if msg == "" { + return msg + } + msg = monitorSensitiveQueryParamRegex.ReplaceAllString(msg, `${1}REDACTED`) + for _, p := range monitorAPIKeyPatterns { + msg = p.pattern.ReplaceAllString(msg, p.replace) + } + return msg +} + +// truncateMessage 把消息按 monitorMessageMaxBytes 截断,避免 DB 列溢出与日志过长。 +func truncateMessage(msg string) string { + if len(msg) <= monitorMessageMaxBytes { + return msg + } + const ellipsis = "...(truncated)" + cutoff := monitorMessageMaxBytes - len(ellipsis) + if cutoff < 0 { + cutoff = 0 + } + return msg[:cutoff] + ellipsis +} + +// truncateForErrorBody 把上游错误响应 body 压到 monitorErrorBodySnippetMaxBytes 以内, +// 并顺手把连续空白折成一个空格:上游 HTML 错误页常含大量缩进/换行,保留会浪费预算。 +// 被 truncateMessage 做最终总截断兜底,所以这里只负责 body 自身的精简。 +func truncateForErrorBody(body string) string { + body = strings.Join(strings.Fields(body), " ") + if len(body) <= monitorErrorBodySnippetMaxBytes { + return body + } + const ellipsis = "...(body truncated)" + cutoff := monitorErrorBodySnippetMaxBytes - len(ellipsis) + if cutoff < 0 { + cutoff = 0 + } + return body[:cutoff] + ellipsis +} diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go new file mode 100644 index 0000000000000000000000000000000000000000..323cf8b70abd95bbe2bab0db609f7524b19f37c3 --- /dev/null +++ b/backend/internal/service/channel_monitor_checker_body_test.go @@ -0,0 +1,173 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// swapMonitorHTTPClient 临时替换 monitorHTTPClient 为不带 SSRF 校验的普通 client, +// 让 httptest (127.0.0.1) 能连通。测试结束后恢复。 +func swapMonitorHTTPClient(t *testing.T) { + t.Helper() + orig := monitorHTTPClient + monitorHTTPClient 
= &http.Client{Timeout: 5 * time.Second} + t.Cleanup(func() { monitorHTTPClient = orig }) +} + +// captureHandler 把每次收到的请求 body 和 headers 存起来,测试断言用。 +type captureHandler struct { + lastBody map[string]any + lastHeaders http.Header + respondText string // 写到 Anthropic content[0].text 里(校验用) + status int +} + +func (h *captureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.lastHeaders = r.Header.Clone() + defer func() { _ = r.Body.Close() }() + var parsed map[string]any + _ = json.NewDecoder(r.Body).Decode(&parsed) + h.lastBody = parsed + + if h.status == 0 { + h.status = 200 + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(h.status) + // 构造 Anthropic 格式的响应:content[0].text = h.respondText + _ = json.NewEncoder(w).Encode(map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": h.respondText}, + }, + }) +} + +func setupFakeAnthropic(t *testing.T, handler *captureHandler) string { + t.Helper() + swapMonitorHTTPClient(t) + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + return srv.URL +} + +func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) { + h := &captureHandler{respondText: "the answer is 42"} + endpoint := setupFakeAnthropic(t, h) + + // 跑一次 off 模式(opts=nil),确认默认 body 行为未变 + _ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", nil) + + if h.lastBody["model"] != "claude-x" { + t.Errorf("default body should contain model=claude-x, got %v", h.lastBody["model"]) + } + if _, ok := h.lastBody["messages"]; !ok { + t.Error("default body should contain messages") + } + if h.lastHeaders.Get("x-api-key") != "sk-fake" { + t.Errorf("expected adapter's x-api-key header, got %q", h.lastHeaders.Get("x-api-key")) + } +} + +func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) { + h := &captureHandler{respondText: "the answer is 42"} + endpoint := setupFakeAnthropic(t, h) + + opts := &CheckOptions{ + BodyOverrideMode: 
MonitorBodyOverrideModeMerge, + BodyOverride: map[string]any{ + "system": "You are Claude Code...", + "max_tokens": float64(999), // 应该覆盖默认 50 + "model": "hacked-model", // 应该被黑名单挡住,保留原 model + "messages": []any{}, // 同上,被挡 + }, + ExtraHeaders: map[string]string{ + "User-Agent": "claude-cli/1.0", + "Content-Length": "999", // 黑名单 + "x-custom": "ok", + }, + } + _ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts) + + if h.lastBody["system"] != "You are Claude Code..." { + t.Errorf("merge mode should inject system, got %v", h.lastBody["system"]) + } + // max_tokens 覆盖生效 + if mt, ok := h.lastBody["max_tokens"].(float64); !ok || mt != 999 { + t.Errorf("merge mode should override max_tokens to 999, got %v", h.lastBody["max_tokens"]) + } + // model 在黑名单 — 应该保留默认值 + if h.lastBody["model"] != "claude-x" { + t.Errorf("model should be protected by deny list, got %v", h.lastBody["model"]) + } + // messages 在黑名单 — 应该保留默认值(非空) + msgs, _ := h.lastBody["messages"].([]any) + if len(msgs) == 0 { + t.Error("messages should be protected by deny list (kept default, non-empty)") + } + // header 合并 + if h.lastHeaders.Get("User-Agent") != "claude-cli/1.0" { + t.Errorf("extra User-Agent should override, got %q", h.lastHeaders.Get("User-Agent")) + } + if h.lastHeaders.Get("x-custom") != "ok" { + t.Errorf("extra custom header should be present, got %q", h.lastHeaders.Get("x-custom")) + } + // Content-Length 黑名单:会被 net/http 自动重算,但不应由用户的 "999" 决定。 + // 我们无法直接断言丢弃(http.Client 总会填上),只断言请求成功即可。 +} + +func TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped(t *testing.T) { + // replace 模式下我们的 body 完全自定义,challenge 数学题不会出现在请求里, + // 上游也不会回正确答案 — 但只要 2xx + 响应文本非空,就算 operational + h := &captureHandler{respondText: "any non-empty text"} + endpoint := setupFakeAnthropic(t, h) + + userBody := map[string]any{ + "model": "user-forced-model", + "messages": []any{map[string]any{"role": "user", "content": "hi"}}, + "max_tokens": float64(10), 
+ "system": "You are someone else", + } + opts := &CheckOptions{ + BodyOverrideMode: MonitorBodyOverrideModeReplace, + BodyOverride: userBody, + } + res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts) + + // 请求 body = 用户提供的原样 + if h.lastBody["model"] != "user-forced-model" { + t.Errorf("replace mode should use user's model, got %v", h.lastBody["model"]) + } + if h.lastBody["system"] != "You are someone else" { + t.Errorf("replace mode should use user's system, got %v", h.lastBody["system"]) + } + // challenge 虽然没命中,但由于 replace 模式跳过 challenge 校验 + 响应非空 → operational + if res.Status != MonitorStatusOperational { + t.Errorf("replace mode with 2xx + non-empty text should be operational, got status=%s message=%q", + res.Status, res.Message) + } +} + +func TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed(t *testing.T) { + h := &captureHandler{respondText: ""} // 上游 200 但 content[0].text 为空 + endpoint := setupFakeAnthropic(t, h) + + opts := &CheckOptions{ + BodyOverrideMode: MonitorBodyOverrideModeReplace, + BodyOverride: map[string]any{"model": "x", "messages": []any{}}, + } + res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts) + + if res.Status != MonitorStatusFailed { + t.Errorf("replace mode with empty text should be failed, got status=%s", res.Status) + } + if !strings.Contains(res.Message, "replace-mode") { + t.Errorf("failure message should hint replace-mode, got %q", res.Message) + } +} diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go new file mode 100644 index 0000000000000000000000000000000000000000..2e1614f79945aeb3fd40c5828fa16a9ea34bb95d --- /dev/null +++ b/backend/internal/service/channel_monitor_const.go @@ -0,0 +1,142 @@ +package service + +import ( + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +// ChannelMonitor 全局常量。 +// 这些是 MVP 
// Hard-coded tuning values for the monitoring stage; promote to config if needed.
const (
	// monitorRequestTimeout is the total timeout for a single model request (including body read).
	monitorRequestTimeout = 45 * time.Second
	// monitorPingTimeout is the timeout for the HEAD request against the endpoint origin.
	monitorPingTimeout = 8 * time.Second
	// monitorDegradedThreshold: a successful primary request slower than this counts as degraded.
	monitorDegradedThreshold = 6 * time.Second
	// monitorHistoryRetentionDays is how many days of detailed history rows are kept.
	// At the default 60s interval, 30 days ≈ 43200 rows per monitor/model; a typical
	// deployment stays under ~2M rows total, which PostgreSQL handles easily. A full
	// month of raw rows is therefore retained, and availability queries can run
	// entirely on raw rows without depending on aggregation. The rollup table
	// channel_monitor_daily_rollups is still kept as a fallback for long-term
	// history backfill / degraded queries.
	monitorHistoryRetentionDays = 30
	// monitorRollupRetentionDays is how many days of daily rollups are kept.
	// Rollup rows past this window are soft-deleted by RunDailyMaintenance.
	monitorRollupRetentionDays = 30
	// monitorMaintenanceMaxDaysPerRun caps how many days one maintenance run aggregates.
	// Sized for the first-deploy backfill (30 days) plus slack, avoiding long transactions.
	monitorMaintenanceMaxDaysPerRun = 35
	// monitorWorkerConcurrency is the number of checks the scheduler runs concurrently (pond pool capacity).
	monitorWorkerConcurrency = 5
	// monitorStartupLoadTimeout bounds the one-shot load of all enabled monitors in Start.
	monitorStartupLoadTimeout = 10 * time.Second
	// monitorMinIntervalSeconds / monitorMaxIntervalSeconds bound the user-configured check interval.
	monitorMinIntervalSeconds = 15
	monitorMaxIntervalSeconds = 3600
	// monitorMessageMaxBytes is the maximum byte size of the message field (matches schema/migration).
	monitorMessageMaxBytes = 500
	// monitorResponseMaxBytes caps how many bytes of a model response are read, preventing OOM.
	monitorResponseMaxBytes = 64 * 1024
	// monitorErrorBodySnippetMaxBytes caps the upstream body snippet kept on non-2xx responses.
	// 300 bytes is enough for typical structured errors (e.g. `{"error":{"message":"..."}}`)
	// while leaving headroom for the "upstream HTTP : " prefix, so the final message is not
	// truncated too aggressively by monitorMessageMaxBytes (500).
	monitorErrorBodySnippetMaxBytes = 300
	// monitorChallengeMin / monitorChallengeMax bound the challenge operand range.
	monitorChallengeMin = 1
	monitorChallengeMax = 50

	// providerOpenAIPath is the OpenAI Chat Completions path.
	providerOpenAIPath = "/v1/chat/completions"
	// providerAnthropicPath is the Anthropic Messages path.
	providerAnthropicPath = "/v1/messages"
	// providerGeminiPathTemplate is the Gemini generateContent path template (with model placeholder).
	providerGeminiPathTemplate = "/v1beta/models/%s:generateContent"

	// MonitorProviderOpenAI / Anthropic / Gemini are the provider string constants
	// (these are also the literal ent enum values).
	MonitorProviderOpenAI    = "openai"
	MonitorProviderAnthropic = "anthropic"
	MonitorProviderGemini    = "gemini"

	// MonitorStatusOperational and friends are monitor status string constants (matching the ent enum).
	MonitorStatusOperational = "operational"
	MonitorStatusDegraded    = "degraded"
	MonitorStatusFailed      = "failed"
	MonitorStatusError       = "error"

	// monitorAvailability7Days / 15 / 30 are the aggregation query windows.
	monitorAvailability7Days  = 7
	monitorAvailability15Days = 15
	monitorAvailability30Days = 30

	// MonitorHistoryDefaultLimit is the default row count for history queries (shared with handlers).
	MonitorHistoryDefaultLimit = 100
	// MonitorHistoryMaxLimit is the maximum row count for history queries (shared with handlers).
	MonitorHistoryMaxLimit = 1000

	// monitorTimelineMaxPoints caps the history points returned per monitor in the user timeline view.
	monitorTimelineMaxPoints = 60

	// monitorEndpointResolveTimeout bounds hostname resolution inside validateEndpoint.
	monitorEndpointResolveTimeout = 5 * time.Second

	// ---- checker / runner behavior parameters (no magic values) ----

	// monitorAnthropicAPIVersion is the Anthropic Messages API version header value.
	monitorAnthropicAPIVersion = "2023-06-01"
	// monitorChallengeMaxTokens is max_tokens for one challenge request (enough for single-digit arithmetic).
	monitorChallengeMaxTokens = 50

	// monitorRunOneBuffer is extra headroom added to runOne's overall timeout
	// (on top of the request timeout and the ping timeout).
	monitorRunOneBuffer = 10 * time.Second

	// monitorIdleConnTimeout is the HTTP transport idle-connection close timeout.
	monitorIdleConnTimeout = 30 * time.Second
	// monitorTLSHandshakeTimeout is the HTTP transport TLS handshake timeout.
	monitorTLSHandshakeTimeout = 10 * time.Second
	// monitorResponseHeaderTimeout is the HTTP transport response-header wait timeout.
	monitorResponseHeaderTimeout = 30 * time.Second
	// monitorPingDiscardMaxBytes caps how many response body bytes are discarded during ping.
	monitorPingDiscardMaxBytes = 1024

	// monitorDialTimeout is the per-connection timeout of the custom dialer.
	monitorDialTimeout = 10 * time.Second
	// monitorDialKeepAlive is the keep-alive interval of the custom dialer.
	monitorDialKeepAlive = 30 * time.Second
)

// Business errors (declared together here instead of scattered across files).
var (
	ErrChannelMonitorNotFound = infraerrors.NotFound(
		"CHANNEL_MONITOR_NOT_FOUND", "channel monitor not found",
	)
	ErrChannelMonitorInvalidProvider = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_PROVIDER", "provider must be one of openai/anthropic/gemini",
	)
	ErrChannelMonitorInvalidInterval = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_INTERVAL", "interval_seconds must be in [15, 3600]",
	)
	ErrChannelMonitorInvalidEndpoint = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_ENDPOINT", "endpoint must be a valid https URL",
	)
	ErrChannelMonitorEndpointScheme = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_SCHEME", "endpoint must use https scheme",
	)
	ErrChannelMonitorEndpointPath = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PATH", "endpoint must be base origin only (no path/query/fragment)",
	)
	ErrChannelMonitorEndpointPrivate = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PRIVATE", "endpoint must be a public host",
	)
	ErrChannelMonitorEndpointUnreachable = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_UNREACHABLE", "endpoint hostname could not be resolved",
	)
	ErrChannelMonitorMissingAPIKey = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_API_KEY", "api_key is required when creating a monitor",
	)
	ErrChannelMonitorMissingPrimaryModel = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_PRIMARY_MODEL", "primary_model is required",
	)
	ErrChannelMonitorAPIKeyDecryptFailed = infraerrors.InternalServer(
		"CHANNEL_MONITOR_KEY_DECRYPT_FAILED", "api key decryption failed; please re-edit the monitor with a fresh key",
	)
)
// MonitorScheduler is the scheduler interface that ChannelMonitorService calls
// back into on CRUD. It is injected via a setter to break the service ↔ runner
// wire dependency cycle.
type MonitorScheduler interface {
	// Schedule creates (or resets) the dedicated timer task for a monitor.
	// When m.Enabled=false it behaves exactly like Unschedule(m.ID).
	Schedule(m *ChannelMonitor)
	// Unschedule cancels the timer task for the given monitor, if one exists.
	Unschedule(id int64)
}

// monitorRunnerSvc extracts the only two service methods the runner depends on:
//   - loading enabled monitors at startup
//   - running a check each time a ticker fires
//
// An interface (rather than *ChannelMonitorService) lets runner unit tests inject
// a lightweight stub instead of the full repo + encryptor chain. The production
// *ChannelMonitorService satisfies it naturally.
type monitorRunnerSvc interface {
	ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error)
	RunCheck(ctx context.Context, id int64) ([]*CheckResult, error)
}

// ChannelMonitorRunner is the channel-monitor scheduler.
//
// Design:
//   - every enabled monitor gets its own goroutine + ticker (per-monitor IntervalSeconds)
//   - Start loads all enabled monitors once and creates a task for each
//   - the service calls back through the MonitorScheduler interface after
//     Create/Update/Delete, rebuilding/cancelling the matching task immediately
//     (no DB polling)
//   - the actual HTTP checks run on a pond pool (capacity monitorWorkerConcurrency)
//     so a burst of due checks cannot overwhelm upstreams
//
// History pruning and daily rollup maintenance are NOT the runner's job: the
// OpsCleanupService cron triggers ChannelMonitorService.RunDailyMaintenance
// (reusing its leader lock + heartbeat).
type ChannelMonitorRunner struct {
	svc            monitorRunnerSvc
	settingService *SettingService

	pool         pond.Pool
	parentCtx    context.Context
	parentCancel context.CancelFunc

	mu      sync.Mutex
	tasks   map[int64]*scheduledMonitor
	wg      sync.WaitGroup
	started bool
	stopped bool

	// inFlight tracks monitor IDs currently executing. fire checks it before
	// submitting so a single monitor is never run concurrently when one check
	// takes longer than its interval.
	inFlight   map[int64]struct{}
	inFlightMu sync.Mutex
}

// scheduledMonitor is the runtime context for one monitor's timer task.
type scheduledMonitor struct {
	id       int64
	name     string
	interval time.Duration
	cancel   context.CancelFunc
}

// NewChannelMonitorRunner constructs the scheduler. Start is called exactly once from wire.
// settingService is read before every fire to honor the feature toggle; passing nil
// means "always enabled" (test convenience).
//
// The pool is created at construction time: this avoids the race of Start assigning
// it under mu while fire/Stop read it outside mu, and pond.NewPool itself is nearly
// free, so building the pool early wastes nothing.
func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
	return newChannelMonitorRunner(svc, settingService)
}

// newChannelMonitorRunner is the internal constructor taking the minimal interface,
// so unit tests can inject a stub service.
func newChannelMonitorRunner(svc monitorRunnerSvc, settingService *SettingService) *ChannelMonitorRunner {
	ctx, cancel := context.WithCancel(context.Background())
	return &ChannelMonitorRunner{
		svc:            svc,
		settingService: settingService,
		pool:           pond.NewPool(monitorWorkerConcurrency),
		parentCtx:      ctx,
		parentCancel:   cancel,
		tasks:          make(map[int64]*scheduledMonitor),
		inFlight:       make(map[int64]struct{}),
	}
}

// Start loads every enabled monitor and creates an independent timer task for each.
// Callers must invoke it only once (wire's ProvideChannelMonitorRunner does).
func (r *ChannelMonitorRunner) Start() {
	if r == nil || r.svc == nil {
		return
	}
	r.mu.Lock()
	if r.started || r.stopped {
		r.mu.Unlock()
		return
	}
	r.started = true
	r.mu.Unlock()

	ctx, cancel := context.WithTimeout(context.Background(), monitorStartupLoadTimeout)
	defer cancel()
	enabled, err := r.svc.ListEnabledMonitors(ctx)
	if err != nil {
		slog.Error("channel_monitor: load enabled monitors failed at startup", "error", err)
		return
	}
	for _, m := range enabled {
		r.Schedule(m)
	}
	slog.Info("channel_monitor: runner started", "scheduled_tasks", len(enabled))
}

// Schedule creates (or resets) the timer task for a monitor.
//   - m.Enabled=false → same as Unschedule(m.ID)
//   - an existing task is cancelled and rebuilt (covers IntervalSeconds changes)
//   - a new task fires its first check immediately, then ticks every IntervalSeconds
func (r *ChannelMonitorRunner) Schedule(m *ChannelMonitor) {
	if r == nil || m == nil {
		return
	}
	if !m.Enabled {
		r.Unschedule(m.ID)
		return
	}
	interval := time.Duration(m.IntervalSeconds) * time.Second
	if interval <= 0 {
		// Create/Update already validate the interval range (validateInterval),
		// so the normal path cannot reach here. If it does, the DB holds data
		// violating the constraint or the validation chain has a bug — log at
		// Error to surface the problem.
		slog.Error("channel_monitor: skip schedule for invalid interval",
			"monitor_id", m.ID, "interval_seconds", m.IntervalSeconds)
		return
	}

	r.mu.Lock()
	if r.stopped {
		r.mu.Unlock()
		return
	}
	if !r.started {
		// Schedule before Start usually means the wire order is broken:
		// the current wiring is SetScheduler → Start, and CRUD hooks can fire at
		// the earliest when a request arrives — by then Start has long finished.
		// If this branch triggers, log the monitor for diagnosis and do nothing
		// (no queueing, no caching) — operators fix it via restart or wiring fix.
		r.mu.Unlock()
		slog.Warn("channel_monitor: schedule before runner started, skip",
			"monitor_id", m.ID, "name", m.Name)
		return
	}
	if existing, ok := r.tasks[m.ID]; ok {
		existing.cancel()
	}
	ctx, cancel := context.WithCancel(r.parentCtx)
	task := &scheduledMonitor{
		id:       m.ID,
		name:     m.Name,
		interval: interval,
		cancel:   cancel,
	}
	r.tasks[m.ID] = task
	r.wg.Add(1)
	r.mu.Unlock()

	go r.runScheduled(ctx, task)
}

// Unschedule cancels the timer task for the given monitor, if one exists.
// NOTE(review): cancellation stops the per-monitor scheduling loop; a check that
// is already in flight runs under its own timeout context created in runOne
// (context.Background-based), so it is NOT interrupted and simply finishes.
func (r *ChannelMonitorRunner) Unschedule(id int64) {
	if r == nil {
		return
	}
	r.mu.Lock()
	task, ok := r.tasks[id]
	if ok {
		delete(r.tasks, id)
	}
	r.mu.Unlock()
	if ok {
		task.cancel()
	}
}

// Stop shuts down gracefully: cancels every task, waits for scheduling
// goroutines, then drains the worker pool (so in-flight checks complete).
func (r *ChannelMonitorRunner) Stop() {
	if r == nil {
		return
	}
	r.mu.Lock()
	if r.stopped {
		r.mu.Unlock()
		return
	}
	r.stopped = true
	r.parentCancel()
	r.tasks = nil
	r.mu.Unlock()

	r.wg.Wait()
	r.pool.StopAndWait()
}

// runScheduled is one monitor's loop: fire immediately once (so "create/enable
// runs right away"), then fire on every interval tick; exit on ctx cancel.
func (r *ChannelMonitorRunner) runScheduled(ctx context.Context, task *scheduledMonitor) {
	defer r.wg.Done()

	r.fire(ctx, task)

	ticker := time.NewTicker(task.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			r.fire(ctx, task)
		}
	}
}

// fire submits one check to the worker pool. When the feature toggle is off the
// tick is skipped (the task stays scheduled, so re-enabling resumes instantly);
// a full pool or an already in-flight check also skips the tick.
func (r *ChannelMonitorRunner) fire(ctx context.Context, task *scheduledMonitor) {
	if r.settingService != nil && !r.settingService.GetChannelMonitorRuntime(ctx).Enabled {
		return
	}
	if !r.tryAcquireInFlight(task.id) {
		slog.Debug("channel_monitor: skip already in-flight",
			"monitor_id", task.id, "name", task.name)
		return
	}
	if _, ok := r.pool.TrySubmit(func() {
		r.runOne(task.id, task.name)
	}); !ok {
		// Pool full: drop this check, but the in-flight slot MUST be released,
		// otherwise this monitor would be stuck forever.
		r.releaseInFlight(task.id)
		slog.Warn("channel_monitor: worker pool full, skip submission",
			"monitor_id", task.id, "name", task.name)
	}
}

// tryAcquireInFlight atomically claims the monitor's in-flight slot.
// Returns false when already claimed (caller should skip this tick).
func (r *ChannelMonitorRunner) tryAcquireInFlight(id int64) bool {
	r.inFlightMu.Lock()
	defer r.inFlightMu.Unlock()
	if _, exists := r.inFlight[id]; exists {
		return false
	}
	r.inFlight[id] = struct{}{}
	return true
}

// releaseInFlight frees the in-flight slot. Must run after runOne finishes
// (including the panic-recover path).
func (r *ChannelMonitorRunner) releaseInFlight(id int64) {
	r.inFlightMu.Lock()
	delete(r.inFlight, id)
	r.inFlightMu.Unlock()
}

// runOne executes one monitor's check. Errors are only logged — no circuit
// breaking. The in-flight slot is always released on exit (incl. panic recover).
// The context is detached (Background) on purpose: a check that has started is
// allowed to finish even if the task is unscheduled meanwhile.
func (r *ChannelMonitorRunner) runOne(id int64, name string) {
	ctx, cancel := context.WithTimeout(context.Background(), monitorRequestTimeout+monitorPingTimeout+monitorRunOneBuffer)
	defer cancel()

	defer r.releaseInFlight(id)

	defer func() {
		if rec := recover(); rec != nil {
			slog.Error("channel_monitor: runner panic",
				"monitor_id", id, "name", name, "panic", rec)
		}
	}()

	if _, err := r.svc.RunCheck(ctx, id); err != nil {
		slog.Warn("channel_monitor: run check failed",
			"monitor_id", id, "name", name, "error", err)
	}
}
// stubMonitorSvc implements monitorRunnerSvc, isolating the runner from the real service/repo.
type stubMonitorSvc struct {
	enabled    []*ChannelMonitor
	runCount   atomic.Int64
	runCalled  chan int64 // pushed once per RunCheck invocation (buffer large enough to avoid blocking)
	runErr     error
	listErr    error
	runHoldFor time.Duration // extra blocking inside RunCheck, used to test Stop's wait behavior
}

// ListEnabledMonitors returns the configured monitors, or listErr when set.
func (s *stubMonitorSvc) ListEnabledMonitors(_ context.Context) ([]*ChannelMonitor, error) {
	if s.listErr != nil {
		return nil, s.listErr
	}
	return s.enabled, nil
}

// RunCheck records the invocation, optionally signals runCalled (non-blocking),
// optionally holds for runHoldFor (or until ctx cancel), then returns runErr.
func (s *stubMonitorSvc) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
	s.runCount.Add(1)
	if s.runCalled != nil {
		select {
		case s.runCalled <- id:
		default:
		}
	}
	if s.runHoldFor > 0 {
		select {
		case <-time.After(s.runHoldFor):
		case <-ctx.Done():
		}
	}
	return nil, s.runErr
}

// newRunnerForTest builds a runner on the stub with no setting service (toggle always on).
func newRunnerForTest(svc monitorRunnerSvc) *ChannelMonitorRunner {
	return newRunnerForTestImpl(svc)
}

// waitFor polls cond every 5ms until it is true or timeout elapses, then t.Fatalf.
func waitFor(t *testing.T, timeout time.Duration, msg string, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if cond() {
			return
		}
		time.Sleep(5 * time.Millisecond)
	}
	if !cond() {
		t.Fatalf("waitFor timed out: %s", msg)
	}
}

// runnerTaskCount reads len(r.tasks) under the runner's lock.
func runnerTaskCount(r *ChannelMonitorRunner) int {
	r.mu.Lock()
	defer r.mu.Unlock()
	return len(r.tasks)
}

// runnerTaskPtr reads the scheduledMonitor pointer for id under the runner's lock.
func runnerTaskPtr(r *ChannelMonitorRunner, id int64) *scheduledMonitor {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.tasks[id]
}

// TestSchedule_AddsTaskAndFiresOnce verifies Schedule fires the first check
// immediately and records the task in the tasks map.
func TestSchedule_AddsTaskAndFiresOnce(t *testing.T) {
	svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	r := newRunnerForTest(svc)
	r.Start() // svc.enabled is empty, so Start completes immediately

	r.Schedule(&ChannelMonitor{ID: 1, Name: "m1", Enabled: true, IntervalSeconds: 60})

	if got := runnerTaskCount(r); got != 1 {
		t.Fatalf("expected 1 scheduled task, got %d", got)
	}

	select {
	case id := <-svc.runCalled:
		if id != 1 {
			t.Fatalf("expected first fire for id=1, got %d", id)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("expected immediate first fire within 2s")
	}

	r.Stop()
}

// TestSchedule_ReplaceCancelsOldTask verifies a second Schedule for the same id
// replaces the old task instance. (The old goroutine exits via ctx cancel; the
// evidence here is a different task pointer plus Stop not timing out.)
func TestSchedule_ReplaceCancelsOldTask(t *testing.T) {
	svc := &stubMonitorSvc{runCalled: make(chan int64, 8)}
	r := newRunnerForTest(svc)
	r.Start()

	m := &ChannelMonitor{ID: 7, Name: "m7", Enabled: true, IntervalSeconds: 60}
	r.Schedule(m)
	first := runnerTaskPtr(r, 7)
	if first == nil {
		t.Fatal("first schedule did not register task")
	}

	r.Schedule(m)
	second := runnerTaskPtr(r, 7)
	if second == nil {
		t.Fatal("second schedule did not register task")
	}
	if first == second {
		t.Fatal("re-Schedule should create a new scheduledMonitor instance")
	}

	stoppedWithin(t, r, 3*time.Second)
}

// TestUnschedule_RemovesTask verifies Unschedule removes the task and its goroutine exits.
func TestUnschedule_RemovesTask(t *testing.T) {
	svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	r := newRunnerForTest(svc)
	r.Start()

	r.Schedule(&ChannelMonitor{ID: 3, Enabled: true, IntervalSeconds: 60})
	waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })

	r.Unschedule(3)
	if got := runnerTaskCount(r); got != 0 {
		t.Fatalf("expected tasks empty after Unschedule, got %d", got)
	}

	stoppedWithin(t, r, 3*time.Second)
}

// TestSchedule_DisabledRedirectsToUnschedule verifies Enabled=false behaves like Unschedule.
func TestSchedule_DisabledRedirectsToUnschedule(t *testing.T) {
	svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	r := newRunnerForTest(svc)
	r.Start()

	r.Schedule(&ChannelMonitor{ID: 9, Enabled: true, IntervalSeconds: 60})
	waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })

	r.Schedule(&ChannelMonitor{ID: 9, Enabled: false, IntervalSeconds: 60})
	if got := runnerTaskCount(r); got != 0 {
		t.Fatalf("expected tasks empty after disabled re-Schedule, got %d", got)
	}

	stoppedWithin(t, r, 3*time.Second)
}

// TestSchedule_InvalidIntervalSkipped verifies IntervalSeconds<=0 registers no task (defensive check).
func TestSchedule_InvalidIntervalSkipped(t *testing.T) {
	svc := &stubMonitorSvc{}
	r := newRunnerForTest(svc)
	r.Start()

	r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 0})
	if got := runnerTaskCount(r); got != 0 {
		t.Fatalf("expected no task for invalid interval, got %d", got)
	}
	r.Stop()
}

// TestSchedule_BeforeStartIsNoOp verifies Schedule before Start registers nothing.
func TestSchedule_BeforeStartIsNoOp(t *testing.T) {
	svc := &stubMonitorSvc{}
	r := newRunnerForTest(svc)
	// deliberately no Start

	r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
	if got := runnerTaskCount(r); got != 0 {
		t.Fatalf("expected no task before Start, got %d", got)
	}
	r.Stop()
}

// TestStart_LoadsAllEnabledMonitors verifies Start creates one task per
// record returned by ListEnabledMonitors.
func TestStart_LoadsAllEnabledMonitors(t *testing.T) {
	svc := &stubMonitorSvc{
		enabled: []*ChannelMonitor{
			{ID: 1, Enabled: true, IntervalSeconds: 60},
			{ID: 2, Enabled: true, IntervalSeconds: 60},
			{ID: 3, Enabled: true, IntervalSeconds: 60},
		},
	}
	r := newRunnerForTest(svc)
	r.Start()
	waitFor(t, 2*time.Second, "all 3 tasks scheduled", func() bool { return runnerTaskCount(r) == 3 })

	stoppedWithin(t, r, 3*time.Second)
}

// TestStop_DrainsAllGoroutines verifies Stop waits for every scheduling goroutine (no strays).
func TestStop_DrainsAllGoroutines(t *testing.T) {
	svc := &stubMonitorSvc{}
	r := newRunnerForTest(svc)
	r.Start()

	for id := int64(1); id <= 5; id++ {
		r.Schedule(&ChannelMonitor{ID: id, Enabled: true, IntervalSeconds: 60})
	}
	waitFor(t, 2*time.Second, "5 tasks scheduled", func() bool { return runnerTaskCount(r) == 5 })

	stoppedWithin(t, r, 3*time.Second)
}

// TestStop_WaitsForInFlightCheck verifies Stop waits for a RunCheck that is
// still executing (pool.StopAndWait).
func TestStop_WaitsForInFlightCheck(t *testing.T) {
	svc := &stubMonitorSvc{
		runCalled:  make(chan int64, 1),
		runHoldFor: 200 * time.Millisecond,
	}
	r := newRunnerForTest(svc)
	r.Start()
	r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})

	select {
	case <-svc.runCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("first fire never happened")
	}

	start := time.Now()
	stoppedWithin(t, r, 3*time.Second)
	elapsed := time.Since(start)
	// Stop must wait for the in-flight check (runHoldFor=200ms); lower bound ~100ms.
	if elapsed < 100*time.Millisecond {
		t.Fatalf("Stop returned too fast (%v); did not wait for in-flight check", elapsed)
	}
}

// TestInFlight_AcquireReleaseSymmetric drives the fire-path invariant directly:
// when pool.TrySubmit fails, the inFlight slot must be released. Mocking
// pond.Pool is awkward, so instead the acquire/release pair is tested for
// symmetry, proving a released slot can be re-acquired (no permanent stall).
func TestInFlight_AcquireReleaseSymmetric(t *testing.T) {
	svc := &stubMonitorSvc{}
	r := newRunnerForTest(svc)

	if !r.tryAcquireInFlight(42) {
		t.Fatal("first acquire should succeed")
	}
	if r.tryAcquireInFlight(42) {
		t.Fatal("second acquire (no release) must fail")
	}
	r.releaseInFlight(42)
	if !r.tryAcquireInFlight(42) {
		t.Fatal("acquire after release should succeed")
	}
	r.releaseInFlight(42)
}

// stoppedWithin calls Stop on a goroutine and fails the test if it does not
// return within timeout — i.e. verifies Stop never blocks indefinitely.
func stoppedWithin(t *testing.T, r *ChannelMonitorRunner, timeout time.Duration) {
	t.Helper()
	done := make(chan struct{})
	var once sync.Once
	go func() {
		r.Stop()
		once.Do(func() { close(done) })
	}()
	select {
	case <-done:
	case <-time.After(timeout):
		t.Fatalf("Stop did not return within %s — leaked goroutine?", timeout)
	}
}
// ChannelMonitorRepository is the data-access interface for channel monitors.
// Parameters/returns use the service package's ChannelMonitor model; the
// repository implementation converts to/from the ent model and keeps the
// api_key_encrypted column as ciphertext.
type ChannelMonitorRepository interface {
	// CRUD
	Create(ctx context.Context, m *ChannelMonitor) error
	GetByID(ctx context.Context, id int64) (*ChannelMonitor, error)
	Update(ctx context.Context, m *ChannelMonitor) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error)

	// Scheduler helpers
	ListEnabled(ctx context.Context) ([]*ChannelMonitor, error)
	MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error
	InsertHistoryBatch(ctx context.Context, rows []*ChannelMonitorHistoryRow) error
	DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error)

	// History
	ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error)

	// User-view aggregation
	ListLatestPerModel(ctx context.Context, monitorID int64) ([]*ChannelMonitorLatest, error)
	ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*ChannelMonitorAvailability, error)

	// Batch aggregation (for admin/user lists, avoiding N+1)
	ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error)
	ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error)
	// ListRecentHistoryForMonitors fetches, per monitor, the latest perMonitorLimit
	// history rows for that monitor's primary model (primaryModels[monitorID]).
	// Returned entries are sorted by checked_at DESC (newest first), message excluded.
	ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)

	// ---------- rollup maintenance (called by OpsCleanupService) ----------

	// UpsertDailyRollupsFor aggregates targetDate's detail rows by
	// (monitor_id, model, bucket_date) into channel_monitor_daily_rollups.
	// targetDate is truncated to the day; ON CONFLICT DO UPDATE makes the
	// backfill idempotent. Returns the number of rows affected by the upsert.
	UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error)
	// DeleteRollupsBefore soft-deletes rollup rows with bucket_date < beforeDate; returns the count.
	DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error)
	// LoadAggregationWatermark reads the watermark (id=1).
	// nil means aggregation has never run; the watermark table is expected to
	// already hold its single row (written by migration 110).
	LoadAggregationWatermark(ctx context.Context) (*time.Time, error)
	// UpdateAggregationWatermark writes the watermark (UPSERT to id=1).
	UpdateAggregationWatermark(ctx context.Context, date time.Time) error
}

// ChannelMonitorService manages channel monitors.
type ChannelMonitorService struct {
	repo      ChannelMonitorRepository
	encryptor SecretEncryptor
	// scheduler is injected by wire via SetScheduler; CRUD methods call the
	// matching hook to synchronize tasks immediately. In tests, or when not
	// injected, it stays nil and every hook call becomes a no-op.
	scheduler MonitorScheduler
}

// NewChannelMonitorService creates a channel-monitor service instance.
func NewChannelMonitorService(repo ChannelMonitorRepository, encryptor SecretEncryptor) *ChannelMonitorService {
	return &ChannelMonitorService{repo: repo, encryptor: encryptor}
}

// ---------- CRUD ----------

// List queries monitors (provider/enabled/search filters + pagination).
// Returned ChannelMonitor.APIKey values are decrypted plaintext; the handler
// layer is responsible for masking.
func (s *ChannelMonitorService) List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error) {
	if params.Page < 1 {
		params.Page = 1
	}
	if params.PageSize < 1 || params.PageSize > 200 {
		params.PageSize = 20
	}
	items, total, err := s.repo.List(ctx, params)
	if err != nil {
		return nil, 0, fmt.Errorf("list channel monitors: %w", err)
	}
	for _, it := range items {
		s.decryptInPlace(it)
	}
	return items, total, nil
}

// Get fetches a single monitor (with decrypted API key).
func (s *ChannelMonitorService) Get(ctx context.Context, id int64) (*ChannelMonitor, error) {
	m, err := s.repo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	s.decryptInPlace(m)
	return m, nil
}

// Create creates a monitor (encrypting api_key internally).
func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCreateParams) (*ChannelMonitor, error) {
	if err := validateCreateParams(p); err != nil {
		return nil, err
	}
	if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
		return nil, err
	}
	if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
		return nil, err
	}
	encrypted, err := s.encryptor.Encrypt(p.APIKey)
	if err != nil {
		return nil, fmt.Errorf("encrypt api key: %w", err)
	}
	m := &ChannelMonitor{
		Name:             strings.TrimSpace(p.Name),
		Provider:         p.Provider,
		Endpoint:         normalizeEndpoint(p.Endpoint),
		APIKey:           encrypted, // NOTE: this field is ciphertext when handed to the repository
		PrimaryModel:     strings.TrimSpace(p.PrimaryModel),
		ExtraModels:      normalizeModels(p.ExtraModels),
		GroupName:        strings.TrimSpace(p.GroupName),
		Enabled:          p.Enabled,
		IntervalSeconds:  p.IntervalSeconds,
		CreatedBy:        p.CreatedBy,
		TemplateID:       p.TemplateID,
		ExtraHeaders:     emptyHeadersIfNil(p.ExtraHeaders),
		BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
		BodyOverride:     p.BodyOverride,
	}
	if err := s.repo.Create(ctx, m); err != nil {
		return nil, fmt.Errorf("create channel monitor: %w", err)
	}
	// Do not re-run s.Get's decryption chain: we already know the plaintext we
	// just encrypted, so build the response directly. This avoids the case where
	// SecretEncryptor decryption fails and APIKey is silently blanked (see Fix 4).
	m.APIKey = strings.TrimSpace(p.APIKey)
	if s.scheduler != nil {
		s.scheduler.Schedule(m)
	}
	return m, nil
}

// validateCreateParams gathers every Create-input validation into one function,
// keeping the Create body under 30 lines.
func validateCreateParams(p ChannelMonitorCreateParams) error {
	if err := validateProvider(p.Provider); err != nil {
		return err
	}
	if err := validateInterval(p.IntervalSeconds); err != nil {
		return err
	}
	if err := validateEndpoint(p.Endpoint); err != nil {
		return err
	}
	if strings.TrimSpace(p.APIKey) == "" {
		return ErrChannelMonitorMissingAPIKey
	}
	if strings.TrimSpace(p.PrimaryModel) == "" {
		return ErrChannelMonitorMissingPrimaryModel
	}
	return nil
}

// Update modifies a monitor. APIKey semantics: nil or empty = leave unchanged;
// non-empty = encrypt and overwrite.
func (s *ChannelMonitorService) Update(ctx context.Context, id int64, p ChannelMonitorUpdateParams) (*ChannelMonitor, error) {
	existing, err := s.repo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if err := applyMonitorUpdate(existing, p); err != nil {
		return nil, err
	}

	newPlainAPIKey, apiKeyUpdated, err := s.applyAPIKeyUpdate(existing, p.APIKey)
	if err != nil {
		return nil, err
	}

	if err := s.repo.Update(ctx, existing); err != nil {
		return nil, fmt.Errorf("update channel monitor: %w", err)
	}

	// Same as Create: skip the second decryption round-trip through s.Get to
	// avoid the "ciphertext silently blanked" failure mode.
	if apiKeyUpdated {
		existing.APIKey = newPlainAPIKey
	} else {
		s.decryptInPlace(existing)
	}
	if s.scheduler != nil {
		// Schedule internally picks Unschedule vs rebuild based on Enabled;
		// an IntervalSeconds change is absorbed naturally (old task cancelled,
		// new task created with the new interval).
		s.scheduler.Schedule(existing)
	}
	return existing, nil
}

// applyAPIKeyUpdate handles the APIKey field inside Update:
//   - raw nil or blank: existing.APIKey is left as-is (still ciphertext), updated=false
//   - non-empty: encrypt and store into existing.APIKey; the plaintext is also
//     returned so the caller can put it back into existing after the DB write,
//     keeping ciphertext out of the client response
func (s *ChannelMonitorService) applyAPIKeyUpdate(existing *ChannelMonitor, raw *string) (plain string, updated bool, err error) {
	if raw == nil || strings.TrimSpace(*raw) == "" {
		return "", false, nil
	}
	plain = strings.TrimSpace(*raw)
	encrypted, encErr := s.encryptor.Encrypt(plain)
	if encErr != nil {
		return "", false, fmt.Errorf("encrypt api key: %w", encErr)
	}
	existing.APIKey = encrypted
	return plain, true, nil
}

// Delete removes a monitor (its history is cleaned up by the FK CASCADE).
func (s *ChannelMonitorService) Delete(ctx context.Context, id int64) error {
	if err := s.repo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete channel monitor: %w", err)
	}
	if s.scheduler != nil {
		s.scheduler.Unschedule(id)
	}
	return nil
}
// ListHistory lists a monitor's recent check history.
// An empty model means all models; limit <= 0 uses the default, and values
// above the maximum are clamped.
func (s *ChannelMonitorService) ListHistory(ctx context.Context, id int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error) {
	if _, err := s.repo.GetByID(ctx, id); err != nil {
		return nil, err
	}
	if limit <= 0 {
		limit = MonitorHistoryDefaultLimit
	}
	if limit > MonitorHistoryMaxLimit {
		limit = MonitorHistoryMaxLimit
	}
	entries, err := s.repo.ListHistory(ctx, id, strings.TrimSpace(model), limit)
	if err != nil {
		return nil, fmt.Errorf("list history: %w", err)
	}
	return entries, nil
}

// ---------- business ----------

// RunCheck synchronously checks one monitor: runs primary + extra models
// concurrently, writes history rows, updates last_checked_at, and returns the
// per-model results.
func (s *ChannelMonitorService) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
	m, err := s.Get(ctx, id) // APIKey already decrypted
	if err != nil {
		return nil, err
	}
	if m.APIKeyDecryptFailed {
		return nil, ErrChannelMonitorAPIKeyDecryptFailed
	}
	results := s.runChecksConcurrent(ctx, m)
	s.persistCheckResults(ctx, m, results)
	return results, nil
}

// persistCheckResults writes this run's history rows and updates last_checked_at.
// Any DB failure is only logged so the caller still receives results (matching
// the MVP expectation: better to drop history than to drop the result).
func (s *ChannelMonitorService) persistCheckResults(ctx context.Context, m *ChannelMonitor, results []*CheckResult) {
	rows := make([]*ChannelMonitorHistoryRow, 0, len(results))
	for _, r := range results {
		rows = append(rows, &ChannelMonitorHistoryRow{
			MonitorID:     m.ID,
			Model:         r.Model,
			Status:        r.Status,
			LatencyMs:     r.LatencyMs,
			PingLatencyMs: r.PingLatencyMs,
			Message:       r.Message,
			CheckedAt:     r.CheckedAt,
		})
	}
	if err := s.repo.InsertHistoryBatch(ctx, rows); err != nil {
		slog.Error("channel_monitor: insert history failed",
			"monitor_id", m.ID, "name", m.Name, "error", err)
	}
	if err := s.repo.MarkChecked(ctx, m.ID, time.Now()); err != nil {
		slog.Error("channel_monitor: mark checked failed",
			"monitor_id", m.ID, "error", err)
	}
}
// runChecksConcurrent checks primary + extra models concurrently.
// errgroup is used only for waiting — no error propagation, since each model's
// failure is already packed into its CheckResult.
func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *ChannelMonitor) []*CheckResult {
	models := append([]string{m.PrimaryModel}, m.ExtraModels...)
	results := make([]*CheckResult, len(models))

	// One shared ping; every model records the same ping latency.
	pingMs := pingEndpointOrigin(ctx, m.Endpoint)

	// All models share one CheckOptions built from the monitor's snapshot fields.
	opts := &CheckOptions{
		ExtraHeaders:     m.ExtraHeaders,
		BodyOverrideMode: m.BodyOverrideMode,
		BodyOverride:     m.BodyOverride,
	}

	var eg errgroup.Group
	var mu sync.Mutex
	for i, model := range models {
		i, model := i, model
		eg.Go(func() error {
			r := runCheckForModel(ctx, m.Provider, m.Endpoint, m.APIKey, model, opts)
			r.PingLatencyMs = pingMs
			mu.Lock()
			results[i] = r
			mu.Unlock()
			return nil
		})
	}
	_ = eg.Wait()
	return results
}

// ---------- scheduler collaboration ----------

// SetScheduler is called by wire after the runner is built, so CRUD paths can
// synchronize the task table immediately. Setter injection breaks the
// service ↔ runner dependency cycle.
func (s *ChannelMonitorService) SetScheduler(sched MonitorScheduler) {
	s.scheduler = sched
}

// ListEnabledMonitors returns all enabled=true monitors (decrypted), for the
// runner to build its task table at startup.
func (s *ChannelMonitorService) ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error) {
	all, err := s.repo.ListEnabled(ctx)
	if err != nil {
		return nil, err
	}
	for _, m := range all {
		s.decryptInPlace(m)
	}
	return all, nil
}

// cleanupOldHistory removes detail history older than monitorHistoryRetentionDays.
// Invoked by RunDailyMaintenance; SoftDeleteMixin turns the DELETE into an
// UPDATE of deleted_at.
func (s *ChannelMonitorService) cleanupOldHistory(ctx context.Context) error {
	before := time.Now().UTC().AddDate(0, 0, -monitorHistoryRetentionDays)
	deleted, err := s.repo.DeleteHistoryBefore(ctx, before)
	if err != nil {
		return fmt.Errorf("delete history before %s: %w", before.Format(time.RFC3339), err)
	}
	if deleted > 0 {
		slog.Info("channel_monitor: history cleanup",
			"deleted_rows", deleted, "before", before.Format(time.RFC3339))
	}
	return nil
}

// RunDailyMaintenance is the daily job: aggregate un-rolled-up detail rows up to
// yesterday, then soft-delete expired details and rollups. Triggered by the
// OpsCleanupService cron (sharing its schedule and leader lock).
//
// Idempotency:
//   - the watermark guarantees already-aggregated days are never reprocessed;
//   - UpsertDailyRollupsFor uses ON CONFLICT DO UPDATE, so re-running the same
//     day produces identical results.
//
// Each step failure is only logged at Warn and the function always returns nil,
// letting subsequent steps run (same style as OpsCleanupService.runCleanupOnce).
func (s *ChannelMonitorService) RunDailyMaintenance(ctx context.Context) error {
	now := time.Now().UTC()
	today := now.Truncate(24 * time.Hour)

	if err := s.runDailyAggregation(ctx, today); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "aggregate", "error", err)
	}
	if err := s.cleanupOldHistory(ctx); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "prune_history", "error", err)
	}
	if err := s.cleanupOldRollups(ctx, today); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "prune_rollups", "error", err)
	}
	return nil
}

// runDailyAggregation aggregates from watermark+1 through yesterday (UTC).
// First run (nil watermark): backfill starting at today-monitorRollupRetentionDays.
// At most monitorMaintenanceMaxDaysPerRun days per run, avoiding long transactions.
func (s *ChannelMonitorService) runDailyAggregation(ctx context.Context, today time.Time) error {
	watermark, err := s.repo.LoadAggregationWatermark(ctx)
	if err != nil {
		return fmt.Errorf("load watermark: %w", err)
	}

	start := s.resolveAggregationStart(watermark, today)
	if !start.Before(today) {
		return nil // nothing left to aggregate
	}

	iterations := 0
	for d := start; d.Before(today); d = d.Add(24 * time.Hour) {
		if iterations >= monitorMaintenanceMaxDaysPerRun {
			slog.Info("channel_monitor: maintenance aggregation capped",
				"max_days", monitorMaintenanceMaxDaysPerRun,
				"next_resume", d.Format("2006-01-02"))
			break
		}
		affected, upErr := s.repo.UpsertDailyRollupsFor(ctx, d)
		if upErr != nil {
			return fmt.Errorf("upsert rollups for %s: %w", d.Format("2006-01-02"), upErr)
		}
		if err := s.repo.UpdateAggregationWatermark(ctx, d); err != nil {
			return fmt.Errorf("update watermark to %s: %w", d.Format("2006-01-02"), err)
		}
		slog.Info("channel_monitor: rollups upserted",
			"date", d.Format("2006-01-02"), "affected_rows", affected)
		iterations++
	}
	return nil
}

// resolveAggregationStart computes this run's aggregation starting day:
//   - watermark == nil: today - monitorRollupRetentionDays (first backfill, max 30 days)
//   - watermark != nil: *watermark + 1 day
func (s *ChannelMonitorService) resolveAggregationStart(watermark *time.Time, today time.Time) time.Time {
	if watermark == nil {
		return today.AddDate(0, 0, -monitorRollupRetentionDays)
	}
	return watermark.UTC().Truncate(24 * time.Hour).Add(24 * time.Hour)
}

// cleanupOldRollups soft-deletes daily rollup rows with
// bucket_date < today - monitorRollupRetentionDays.
func (s *ChannelMonitorService) cleanupOldRollups(ctx context.Context, today time.Time) error {
	cutoff := today.AddDate(0, 0, -monitorRollupRetentionDays)
	deleted, err := s.repo.DeleteRollupsBefore(ctx, cutoff)
	if err != nil {
		return fmt.Errorf("delete rollups before %s: %w", cutoff.Format("2006-01-02"), err)
	}
	if deleted > 0 {
		slog.Info("channel_monitor: rollups cleanup",
			"deleted_rows", deleted, "before", cutoff.Format("2006-01-02"))
	}
	return nil
}

// ---------- helpers ----------

// decryptInPlace decrypts ChannelMonitor.APIKey from ciphertext to plaintext.
// On failure the field is blanked and APIKeyDecryptFailed=true is set (no error
// returned, so list rendering is never blocked). The runner / RunCheck must
// consult that flag and refuse to execute a check.
func (s *ChannelMonitorService) decryptInPlace(m *ChannelMonitor) {
	if m == nil || m.APIKey == "" {
		return
	}
	plain, err := s.encryptor.Decrypt(m.APIKey)
	if err != nil {
		slog.Warn("channel_monitor: decrypt api key failed",
			"monitor_id", m.ID, "error", err)
		m.APIKey = ""
		m.APIKeyDecryptFailed = true
		return
	}
	m.APIKey = plain
}

// applyMonitorUpdate applies every non-nil field from the update params onto
// existing. The APIKey field is handled separately by the caller (encryption).
//
// Slightly over 30 lines: this is a flat per-field dispatcher where each `if`
// is a 1–3 line "overwrite when non-nil" clause — splitting it would add jump
// noise and hurt readability, so it stays a single function.
func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
	if p.Name != nil {
		existing.Name = strings.TrimSpace(*p.Name)
	}
	if p.Provider != nil {
		if err := validateProvider(*p.Provider); err != nil {
			return err
		}
		existing.Provider = *p.Provider
	}
	if p.Endpoint != nil {
		if err := validateEndpoint(*p.Endpoint); err != nil {
			return err
		}
		existing.Endpoint = normalizeEndpoint(*p.Endpoint)
	}
	if p.PrimaryModel != nil {
		existing.PrimaryModel = strings.TrimSpace(*p.PrimaryModel)
	}
	if p.ExtraModels != nil {
		existing.ExtraModels = normalizeModels(*p.ExtraModels)
	}
	if p.GroupName != nil {
		existing.GroupName = strings.TrimSpace(*p.GroupName)
	}
	if p.Enabled != nil {
		existing.Enabled = *p.Enabled
	}
	if p.IntervalSeconds != nil {
		if err := validateInterval(*p.IntervalSeconds); err != nil {
			return err
		}
		existing.IntervalSeconds = *p.IntervalSeconds
	}
	return applyMonitorAdvancedUpdate(existing, p)
}

// applyMonitorAdvancedUpdate handles the custom-request snapshot fields; split
// out of applyMonitorUpdate to keep both short.
func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
	if p.ClearTemplate {
		existing.TemplateID = nil
	} else if p.TemplateID != nil {
		id := *p.TemplateID
		existing.TemplateID = &id
	}
	if p.ExtraHeaders != nil {
		if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
			return err
		}
		existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
	}
	// BodyOverrideMode / BodyOverride are validated jointly, same as templates.
	newMode := existing.BodyOverrideMode
	newBody := existing.BodyOverride
	if p.BodyOverrideMode != nil {
		newMode = *p.BodyOverrideMode
	}
	if p.BodyOverride != nil {
		newBody = *p.BodyOverride
	}
	if p.BodyOverrideMode != nil || p.BodyOverride != nil {
		if err := validateBodyModeParams(newMode, newBody); err != nil {
			return err
		}
		existing.BodyOverrideMode = defaultBodyMode(newMode)
		existing.BodyOverride = newBody
	}
	return nil
}
existing.BodyOverride = newBody + } + return nil +} diff --git a/backend/internal/service/channel_monitor_ssrf.go b/backend/internal/service/channel_monitor_ssrf.go new file mode 100644 index 0000000000000000000000000000000000000000..8d93f6004593f3cb32e691c4831f1881ef5977e3 --- /dev/null +++ b/backend/internal/service/channel_monitor_ssrf.go @@ -0,0 +1,152 @@ +package service + +import ( + "context" + "net" + "strings" +) + +// SSRF 防护 helper: +// - validateEndpoint 在 admin 提交时阻止 http/loopback/私网/云元数据 URL +// - safeDialContext 在 socket 层再次校验真实 IP,防止 DNS rebinding +// +// 已知 cloud metadata hostname 拒绝列表(小写比较)。 +var monitorBlockedHostnames = map[string]struct{}{ + "localhost": {}, + "localhost.localdomain": {}, + "metadata": {}, + "metadata.google.internal": {}, + "metadata.goog": {}, + "instance-data": {}, + "instance-data.ec2.internal": {}, +} + +// CIDR 列表:包含所有需要拒绝的 IPv4/IPv6 段。 +// 解析时只 panic 一次(启动时确认),生产路径只做 Contains。 +var monitorBlockedCIDRs = mustParseCIDRs([]string{ + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC1918 + "172.16.0.0/12", // RFC1918 + "192.168.0.0/16", // RFC1918 + "169.254.0.0/16", // link-local(含云元数据 169.254.169.254) + "100.64.0.0/10", // CGNAT + "0.0.0.0/8", // "this network" + "::1/128", // IPv6 loopback + "fc00::/7", // IPv6 ULA + "fe80::/10", // IPv6 link-local + "::/128", // IPv6 unspecified +}) + +// monitorDialer 共享 Dialer,与 net/http 默认值对齐。 +var monitorDialer = &net.Dialer{ + Timeout: monitorDialTimeout, + KeepAlive: monitorDialKeepAlive, +} + +// mustParseCIDRs 在包初始化时解析 CIDR 字符串,失败 panic。 +func mustParseCIDRs(cidrs []string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(cidrs)) + for _, c := range cidrs { + _, n, err := net.ParseCIDR(c) + if err != nil { + panic("channel_monitor_ssrf: invalid CIDR " + c + ": " + err.Error()) + } + out = append(out, n) + } + return out +} + +// isBlockedHostname 判断 hostname 是否命中黑名单。 +func isBlockedHostname(hostname string) bool { + if hostname == "" { + return true + } + _, blocked := 
monitorBlockedHostnames[strings.ToLower(hostname)] + return blocked +} + +// isPrivateIP 判断 IP 是否落在禁止段(loopback/RFC1918/link-local/ULA 等)。 +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return true + } + if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() { + return true + } + for _, n := range monitorBlockedCIDRs { + if n.Contains(ip) { + return true + } + } + return false +} + +// isPrivateOrLoopbackHost 解析 hostname 的所有 A/AAAA 记录, +// 任一 IP 落在私网/loopback 段即认为不安全。 +// +// hostname 是 IP 字面量时也走同一路径。 +func isPrivateOrLoopbackHost(ctx context.Context, hostname string) (bool, error) { + if isBlockedHostname(hostname) { + return true, nil + } + // IP 字面量直接判断。 + if ip := net.ParseIP(hostname); ip != nil { + return isPrivateIP(ip), nil + } + resolver := net.DefaultResolver + addrs, err := resolver.LookupIPAddr(ctx, hostname) + if err != nil { + return false, err + } + if len(addrs) == 0 { + return true, nil + } + for _, a := range addrs { + if isPrivateIP(a.IP) { + return true, nil + } + } + return false, nil +} + +// safeDialContext 在真实 dial 前再次校验目标 IP,防止 DNS rebinding。 +// 解析 hostname 后逐个 IP 尝试连接,命中私网即拒绝(即便 validateEndpoint 时返回的是公网 IP)。 +func safeDialContext(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + // 字面量 IP 走快速路径。 + if ip := net.ParseIP(host); ip != nil { + if isPrivateIP(ip) { + return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address} + } + return monitorDialer.DialContext(ctx, network, address) + } + if isBlockedHostname(host) { + return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address} + } + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(addrs) == 0 { + return nil, &net.AddrError{Err: "no addresses for host", Addr: host} + } + var lastErr error + for _, a := range addrs { + if 
isPrivateIP(a.IP) { + lastErr = &net.AddrError{Err: "blocked by SSRF policy", Addr: a.IP.String()} + continue + } + conn, err := monitorDialer.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port)) + if err == nil { + return conn, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = &net.AddrError{Err: "no usable addresses", Addr: host} + } + return nil, lastErr +} diff --git a/backend/internal/service/channel_monitor_template_service.go b/backend/internal/service/channel_monitor_template_service.go new file mode 100644 index 0000000000000000000000000000000000000000..8d2e8173f64dae32af5988e558c0eddb45f15536 --- /dev/null +++ b/backend/internal/service/channel_monitor_template_service.go @@ -0,0 +1,251 @@ +package service + +import ( + "context" + "fmt" + "regexp" + "strings" +) + +// ChannelMonitorRequestTemplateRepository 模板数据访问接口。 +type ChannelMonitorRequestTemplateRepository interface { + Create(ctx context.Context, t *ChannelMonitorRequestTemplate) error + GetByID(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) + Update(ctx context.Context, t *ChannelMonitorRequestTemplate) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error) + // ApplyToMonitors 把模板当前的 extra_headers / body_override_mode / body_override + // 批量覆盖到指定 monitorIDs 的监控上(同时还要求这些监控当前 template_id = id, + // 防止误覆盖未关联的监控)。monitorIDs 必须非空;空列表直接返回 0 不写库。 + // 返回被覆盖的监控数量。 + ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) + // CountAssociatedMonitors 统计 template_id = id 的监控数(用于 UI 展示「应用到 N 个配置」)。 + CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) + // ListAssociatedMonitors 列出所有 template_id = id 的监控简略信息(id/name/provider/enabled) + // 给 apply picker UI 用,避免前端再做一次 list+filter。 + ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error) +} + +// AssociatedMonitorBrief 
模板关联监控的简略信息(picker / 列表展示用)。 +type AssociatedMonitorBrief struct { + ID int64 + Name string + Provider string + Enabled bool +} + +// ChannelMonitorRequestTemplateService 模板管理 service。 +type ChannelMonitorRequestTemplateService struct { + repo ChannelMonitorRequestTemplateRepository +} + +// NewChannelMonitorRequestTemplateService 创建模板 service。 +func NewChannelMonitorRequestTemplateService(repo ChannelMonitorRequestTemplateRepository) *ChannelMonitorRequestTemplateService { + return &ChannelMonitorRequestTemplateService{repo: repo} +} + +// ---------- CRUD ---------- + +// List 按 provider 过滤(空串 = 全部),不分页(模板量级小)。 +func (s *ChannelMonitorRequestTemplateService) List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error) { + if params.Provider != "" { + if err := validateProvider(params.Provider); err != nil { + return nil, err + } + } + return s.repo.List(ctx, params) +} + +// Get 返回单个模板。 +func (s *ChannelMonitorRequestTemplateService) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) { + return s.repo.GetByID(ctx, id) +} + +// Create 创建模板(会校验 headers 黑名单和 body 模式匹配)。 +func (s *ChannelMonitorRequestTemplateService) Create(ctx context.Context, p ChannelMonitorRequestTemplateCreateParams) (*ChannelMonitorRequestTemplate, error) { + if err := validateTemplateCreateParams(p); err != nil { + return nil, err + } + t := &ChannelMonitorRequestTemplate{ + Name: strings.TrimSpace(p.Name), + Provider: p.Provider, + Description: strings.TrimSpace(p.Description), + ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders), + BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode), + BodyOverride: p.BodyOverride, + } + if err := s.repo.Create(ctx, t); err != nil { + return nil, fmt.Errorf("create template: %w", err) + } + return t, nil +} + +// Update 更新模板(provider 不可改)。 +func (s *ChannelMonitorRequestTemplateService) Update(ctx context.Context, id int64, p ChannelMonitorRequestTemplateUpdateParams) 
(*ChannelMonitorRequestTemplate, error) { + existing, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + if err := applyTemplateUpdate(existing, p); err != nil { + return nil, err + } + if err := s.repo.Update(ctx, existing); err != nil { + return nil, fmt.Errorf("update template: %w", err) + } + return existing, nil +} + +// Delete 删除模板。关联监控的 template_id 会被 SET NULL,监控保留快照继续跑。 +func (s *ChannelMonitorRequestTemplateService) Delete(ctx context.Context, id int64) error { + if err := s.repo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete template: %w", err) + } + return nil +} + +// ApplyToMonitors 把模板当前配置应用到 monitorIDs 列表里的关联监控。 +// monitorIDs 必须非空且每个 id 都必须当前 template_id = id;不满足条件的会被 SQL WHERE 过滤掉。 +// 返回实际被覆盖的监控数。 +func (s *ChannelMonitorRequestTemplateService) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) { + if _, err := s.repo.GetByID(ctx, id); err != nil { + return 0, err + } + if len(monitorIDs) == 0 { + return 0, ErrChannelMonitorTemplateApplyEmpty + } + affected, err := s.repo.ApplyToMonitors(ctx, id, monitorIDs) + if err != nil { + return 0, fmt.Errorf("apply template to monitors: %w", err) + } + return affected, nil +} + +// CountAssociatedMonitors 返回关联监控数。 +func (s *ChannelMonitorRequestTemplateService) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) { + return s.repo.CountAssociatedMonitors(ctx, id) +} + +// ListAssociatedMonitors 返回模板关联的所有监控简略信息。 +// 给前端 apply picker 用,handler 直接吐 JSON 不再做 join。 +func (s *ChannelMonitorRequestTemplateService) ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error) { + if _, err := s.repo.GetByID(ctx, id); err != nil { + return nil, err + } + return s.repo.ListAssociatedMonitors(ctx, id) +} + +// ---------- 校验 & 工具 ---------- + +// validateTemplateCreateParams 聚合 create 入参校验,避免函数超过 30 行。 +func validateTemplateCreateParams(p ChannelMonitorRequestTemplateCreateParams) error { + if 
strings.TrimSpace(p.Name) == "" { + return ErrChannelMonitorTemplateMissingName + } + if err := validateProvider(p.Provider); err != nil { + return ErrChannelMonitorTemplateInvalidProvider + } + if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil { + return err + } + if err := validateExtraHeaders(p.ExtraHeaders); err != nil { + return err + } + return nil +} + +// applyTemplateUpdate 把 update params 中非 nil 字段应用到 existing 上。 +func applyTemplateUpdate(existing *ChannelMonitorRequestTemplate, p ChannelMonitorRequestTemplateUpdateParams) error { + if p.Name != nil { + name := strings.TrimSpace(*p.Name) + if name == "" { + return ErrChannelMonitorTemplateMissingName + } + existing.Name = name + } + if p.Description != nil { + existing.Description = strings.TrimSpace(*p.Description) + } + if p.ExtraHeaders != nil { + if err := validateExtraHeaders(*p.ExtraHeaders); err != nil { + return err + } + existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders) + } + // BodyOverrideMode / BodyOverride 联合校验:任一变化都用「更新后的值」做校验。 + newMode := existing.BodyOverrideMode + newBody := existing.BodyOverride + if p.BodyOverrideMode != nil { + newMode = *p.BodyOverrideMode + } + if p.BodyOverride != nil { + newBody = *p.BodyOverride + } + if err := validateBodyModeParams(newMode, newBody); err != nil { + return err + } + existing.BodyOverrideMode = defaultBodyMode(newMode) + existing.BodyOverride = newBody + return nil +} + +// validateBodyModeParams 校验 body_override_mode 合法,且 merge/replace 模式下 body_override 非空。 +func validateBodyModeParams(mode string, body map[string]any) error { + switch mode { + case "", MonitorBodyOverrideModeOff: + return nil + case MonitorBodyOverrideModeMerge, MonitorBodyOverrideModeReplace: + if len(body) == 0 { + return ErrChannelMonitorTemplateBodyRequired + } + return nil + default: + return ErrChannelMonitorTemplateInvalidBodyMode + } +} + +// headerNameRegex 合法 header 名:RFC 7230 token(ASCII 可见字符减特殊符号)。 +var headerNameRegex = 
regexp.MustCompile(`^[A-Za-z0-9!#$%&'*+\-.^_` + "`" + `|~]+$`) + +// forbiddenHeaderNames hop-by-hop + HTTP 客户端自管的 header;禁止用户覆盖, +// 否则会让 Go http.Client 行为异常(双重 Content-Length、连接复用错乱等)。 +var forbiddenHeaderNames = map[string]bool{ + "host": true, + "content-length": true, + "content-encoding": true, + "transfer-encoding": true, + "connection": true, +} + +// IsForbiddenHeaderName 对外暴露,checker 运行时也会再过滤一次做兜底。 +func IsForbiddenHeaderName(name string) bool { + return forbiddenHeaderNames[strings.ToLower(strings.TrimSpace(name))] +} + +// validateExtraHeaders 校验 header 名字格式 + 黑名单。保存时就拒绝非法 header,早失败。 +func validateExtraHeaders(h map[string]string) error { + for k := range h { + if !headerNameRegex.MatchString(k) { + return ErrChannelMonitorTemplateHeaderInvalidName + } + if IsForbiddenHeaderName(k) { + return ErrChannelMonitorTemplateHeaderForbidden + } + } + return nil +} + +// emptyHeadersIfNil 把 nil map 归一成空 map(repo 层写库时 JSONB 需要非 nil)。 +func emptyHeadersIfNil(h map[string]string) map[string]string { + if h == nil { + return map[string]string{} + } + return h +} + +// defaultBodyMode 空串归一为 off。 +func defaultBodyMode(mode string) string { + if mode == "" { + return MonitorBodyOverrideModeOff + } + return mode +} diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go new file mode 100644 index 0000000000000000000000000000000000000000..e5bf75684436bcc1b17d24ce071f91c7d071574a --- /dev/null +++ b/backend/internal/service/channel_monitor_template_types.go @@ -0,0 +1,77 @@ +package service + +import ( + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "time" +) + +// ChannelMonitorRequestTemplate 请求模板(service 层模型)。 +// 作用:把一组可复用的 headers + 可选 body 覆盖配置抽出来管理, +// 被监控「应用」时以快照方式拷贝到监控本身的同名字段。 +type ChannelMonitorRequestTemplate struct { + ID int64 + Name string + Provider string + Description string + ExtraHeaders map[string]string + BodyOverrideMode string + BodyOverride 
map[string]any + CreatedAt time.Time + UpdatedAt time.Time +} + +// ChannelMonitorRequestTemplateListParams 列表过滤。 +type ChannelMonitorRequestTemplateListParams struct { + Provider string // 空 = 全部;非空则按 provider 过滤 +} + +// ChannelMonitorRequestTemplateCreateParams 创建参数。 +type ChannelMonitorRequestTemplateCreateParams struct { + Name string + Provider string + Description string + ExtraHeaders map[string]string + BodyOverrideMode string + BodyOverride map[string]any +} + +// ChannelMonitorRequestTemplateUpdateParams 更新参数(指针字段 = 不修改)。 +// 注意 Provider 不可修改:改 provider 会让已关联监控的 body 黑名单语义错乱。 +type ChannelMonitorRequestTemplateUpdateParams struct { + Name *string + Description *string + ExtraHeaders *map[string]string + BodyOverrideMode *string + BodyOverride *map[string]any +} + +// 模板相关错误(命名与现有 ErrChannelMonitor* 风格保持一致)。 +var ( + ErrChannelMonitorTemplateNotFound = infraerrors.NotFound( + "CHANNEL_MONITOR_TEMPLATE_NOT_FOUND", "channel monitor request template not found", + ) + ErrChannelMonitorTemplateInvalidProvider = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_INVALID_PROVIDER", "template provider must be one of openai/anthropic/gemini", + ) + ErrChannelMonitorTemplateMissingName = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_MISSING_NAME", "template name is required", + ) + ErrChannelMonitorTemplateInvalidBodyMode = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_INVALID_BODY_MODE", "body_override_mode must be one of off/merge/replace", + ) + ErrChannelMonitorTemplateBodyRequired = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_BODY_REQUIRED", "body_override is required when body_override_mode is merge or replace", + ) + ErrChannelMonitorTemplateHeaderForbidden = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_HEADER_FORBIDDEN", "header name is forbidden (hop-by-hop or computed by HTTP client)", + ) + ErrChannelMonitorTemplateHeaderInvalidName = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_HEADER_INVALID_NAME", "header name 
contains invalid characters", + ) + ErrChannelMonitorTemplateProviderMismatch = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_PROVIDER_MISMATCH", "monitor provider does not match template provider", + ) + ErrChannelMonitorTemplateApplyEmpty = infraerrors.BadRequest( + "CHANNEL_MONITOR_TEMPLATE_APPLY_EMPTY", "monitor_ids must be a non-empty array", + ) +) diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go new file mode 100644 index 0000000000000000000000000000000000000000..b797a89b7c774c5343dd0c3b79bc474da32f206a --- /dev/null +++ b/backend/internal/service/channel_monitor_types.go @@ -0,0 +1,203 @@ +package service + +import "time" + +// MonitorBodyOverrideMode 自定义请求体处理模式。 +// +// - off 使用 adapter 默认 body(忽略 BodyOverride) +// - merge adapter 默认 body 与 BodyOverride 浅合并(用户优先; +// model/messages/contents 等关键字段在 checker 黑名单内会被静默丢弃) +// - replace 完全用 BodyOverride 作为 body;跳过 challenge 校验, +// 改成 HTTP 2xx + 响应非空即视为可用(用户负责构造 body) +const ( + MonitorBodyOverrideModeOff = "off" + MonitorBodyOverrideModeMerge = "merge" + MonitorBodyOverrideModeReplace = "replace" +) + +// ChannelMonitor 渠道监控配置(service 层模型,不直接暴露 ent 类型)。 +type ChannelMonitor struct { + ID int64 + Name string + Provider string + Endpoint string + APIKey string // 解密后的明文 API Key(仅在 service 内部使用,handler 层不应直接序列化返回) + PrimaryModel string + ExtraModels []string + GroupName string + Enabled bool + IntervalSeconds int + LastCheckedAt *time.Time + CreatedBy int64 + CreatedAt time.Time + UpdatedAt time.Time + + // 请求自定义快照(来自模板拷贝 or 用户手填,运行时直接读取) + TemplateID *int64 // 仅用于 UI 分组 + 一键应用,运行时不用 + ExtraHeaders map[string]string // 与 adapter 默认 headers 合并,用户优先 + BodyOverrideMode string // off / merge / replace + BodyOverride map[string]any // 仅 mode != off 时使用 + + // APIKeyDecryptFailed 表示 APIKey 字段无法解密(密钥不一致或损坏)。 + // 此时 APIKey 为空字符串,runner / RunCheck 必须跳过该监控并提示重填。 + APIKeyDecryptFailed bool +} + +// ChannelMonitorListParams 列表查询过滤参数。 +type 
ChannelMonitorListParams struct { + Page int + PageSize int + Provider string + Enabled *bool + Search string +} + +// ChannelMonitorCreateParams 创建参数。 +type ChannelMonitorCreateParams struct { + Name string + Provider string + Endpoint string + APIKey string + PrimaryModel string + ExtraModels []string + GroupName string + Enabled bool + IntervalSeconds int + CreatedBy int64 + TemplateID *int64 + ExtraHeaders map[string]string + BodyOverrideMode string + BodyOverride map[string]any +} + +// ChannelMonitorUpdateParams 更新参数(指针字段表示"未提供则不更新")。 +type ChannelMonitorUpdateParams struct { + Name *string + Provider *string + Endpoint *string + APIKey *string // 空字符串表示不修改;非空字符串覆盖 + PrimaryModel *string + ExtraModels *[]string + GroupName *string + Enabled *bool + IntervalSeconds *int + // 自定义快照字段:指针为 nil 表示不更新,非 nil 覆盖 + // TemplateID *(*int64):用 ** 表达三态:nil=不更新;&nil=清空;&&id=设为 id。 + // 简化处理:用 ClearTemplate 显式标志 + TemplateID(普通指针) + TemplateID *int64 + ClearTemplate bool // true 时无视 TemplateID,把监控的 template_id 置空 + ExtraHeaders *map[string]string + BodyOverrideMode *string + BodyOverride *map[string]any +} + +// CheckResult 单个模型一次检测的结果。 +type CheckResult struct { + Model string + Status string // operational / degraded / failed / error + LatencyMs *int + PingLatencyMs *int + Message string + CheckedAt time.Time +} + +// UserMonitorView 用户只读视图:监控概览(含主模型最近状态 + 7d 可用率 + 附加模型最近状态)。 +type UserMonitorView struct { + ID int64 + Name string + Provider string + GroupName string + PrimaryModel string + PrimaryStatus string + PrimaryLatencyMs *int + PrimaryPingLatencyMs *int // 主模型最近一次 ping 延迟 + Availability7d float64 // 0-100 + ExtraModels []ExtraModelStatus + Timeline []UserMonitorTimelinePoint // 主模型最近 N 个历史点(按 checked_at DESC,最新在前) +} + +// UserMonitorTimelinePoint 用户视图 timeline 单点数据(去除 message 以减小响应体)。 +type UserMonitorTimelinePoint struct { + Status string `json:"status"` + LatencyMs *int `json:"latency_ms"` + PingLatencyMs *int `json:"ping_latency_ms"` + CheckedAt time.Time 
`json:"checked_at"` +} + +// ExtraModelStatus 附加模型最近一次状态。 +type ExtraModelStatus struct { + Model string + Status string + LatencyMs *int +} + +// UserMonitorDetail 用户只读视图:监控详情(含全部模型 7d/15d/30d 可用率与平均延迟)。 +type UserMonitorDetail struct { + ID int64 + Name string + Provider string + GroupName string + Models []ModelDetail +} + +// ModelDetail 单个模型的可用率/延迟统计。 +type ModelDetail struct { + Model string + LatestStatus string + LatestLatencyMs *int + Availability7d float64 // 0-100 + Availability15d float64 + Availability30d float64 + AvgLatency7dMs *int +} + +// ChannelMonitorHistoryRow 历史记录入库行(service 层向 repository 提交的数据)。 +type ChannelMonitorHistoryRow struct { + MonitorID int64 + Model string + Status string + LatencyMs *int + PingLatencyMs *int + Message string + CheckedAt time.Time +} + +// ChannelMonitorHistoryEntry 历史记录查询返回行(含 ent 主键 ID)。 +type ChannelMonitorHistoryEntry struct { + ID int64 + Model string + Status string + LatencyMs *int + PingLatencyMs *int + Message string + CheckedAt time.Time +} + +// ChannelMonitorLatest 最近一次检测的简明信息(用于 UserMonitorView 聚合)。 +type ChannelMonitorLatest struct { + Model string + Status string + LatencyMs *int + PingLatencyMs *int + CheckedAt time.Time +} + +// ChannelMonitorAvailability 单个模型在某窗口内的可用率与平均延迟(用于 UserMonitorDetail 聚合)。 +type ChannelMonitorAvailability struct { + Model string + WindowDays int + TotalChecks int + OperationalChecks int // operational + degraded 视为可用 + AvailabilityPct float64 + AvgLatencyMs *int +} + +// MonitorStatusSummary 监控状态聚合(admin list 用,单次 repo 查询消除前端 N+1)。 +// PrimaryStatus / PrimaryLatencyMs 描述主模型最近状态;Availability7d 是主模型 7 天可用率; +// ExtraModels 描述附加模型最近状态(用于 hover 展示)。 +type MonitorStatusSummary struct { + PrimaryStatus string // 空字符串表示无历史 + PrimaryLatencyMs *int + Availability7d float64 // 0-100,无历史时为 0 + ExtraModels []ExtraModelStatus +} diff --git a/backend/internal/service/channel_monitor_validate.go b/backend/internal/service/channel_monitor_validate.go new file mode 100644 index 
0000000000000000000000000000000000000000..16bbec7174a628219bfc8f84e26966d97935cca1 --- /dev/null +++ b/backend/internal/service/channel_monitor_validate.go @@ -0,0 +1,99 @@ +package service + +import ( + "context" + "net/url" + "strings" +) + +// 渠道监控参数校验与归一化辅助函数。 +// 校验失败一律返回 channel_monitor_const.go 中预定义的 Err* 错误,错误信息不含具体 IP/hostname,避免泄露内网拓扑。 + +// validateProvider 校验 provider 字符串。 +// 唯一来源于 providerAdapters:新增 provider 只需要在 channel_monitor_checker.go 注册 adapter。 +func validateProvider(p string) error { + if !isSupportedProvider(p) { + return ErrChannelMonitorInvalidProvider + } + return nil +} + +// validateInterval 校验 interval_seconds 范围。 +func validateInterval(sec int) error { + if sec < monitorMinIntervalSeconds || sec > monitorMaxIntervalSeconds { + return ErrChannelMonitorInvalidInterval + } + return nil +} + +// validateEndpoint 校验 endpoint: +// - scheme 强制 https(拒绝 http,避免明文凭证 + 部分 SSRF 利用面) +// - 必须为 origin(无 path/query/fragment),防止用户填 https://api.openai.com/v1 +// 导致 joinURL 拼出 /v1/v1/chat/completions +// - hostname 不能是 localhost/metadata 等已知元数据 hostname +// - 解析所有 IP,任一落在 loopback/RFC1918/link-local/ULA 段即拒绝(防 SSRF) +// +// 错误信息不暴露具体 IP / hostname,避免泄露内网拓扑。 +func validateEndpoint(ep string) error { + ep = strings.TrimSpace(ep) + if ep == "" { + return ErrChannelMonitorInvalidEndpoint + } + u, err := url.Parse(ep) + if err != nil { + return ErrChannelMonitorInvalidEndpoint + } + if u.Scheme != "https" { + return ErrChannelMonitorEndpointScheme + } + if u.Host == "" { + return ErrChannelMonitorInvalidEndpoint + } + if u.Path != "" && u.Path != "/" { + return ErrChannelMonitorEndpointPath + } + if u.RawQuery != "" || u.Fragment != "" { + return ErrChannelMonitorEndpointPath + } + + hostname := u.Hostname() + ctx, cancel := context.WithTimeout(context.Background(), monitorEndpointResolveTimeout) + defer cancel() + blocked, err := isPrivateOrLoopbackHost(ctx, hostname) + if err != nil { + return ErrChannelMonitorEndpointUnreachable + } + if blocked { + 
return ErrChannelMonitorEndpointPrivate + } + return nil +} + +// normalizeEndpoint 去除前后空白与末尾 `/`,保证存储统一为 origin。 +// validateEndpoint 已确保格式合法(仅 origin),这里只做最终归一化。 +func normalizeEndpoint(ep string) string { + ep = strings.TrimSpace(ep) + ep = strings.TrimRight(ep, "/") + return ep +} + +// normalizeModels 去除空白、重复模型名。保留输入顺序(map 的迭代顺序无关)。 +func normalizeModels(in []string) []string { + if len(in) == 0 { + return []string{} + } + seen := make(map[string]struct{}, len(in)) + out := make([]string, 0, len(in)) + for _, m := range in { + m = strings.TrimSpace(m) + if m == "" { + continue + } + if _, ok := seen[m]; ok { + continue + } + seen[m] = struct{}{} + out = append(out, m) + } + return out +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index c29550d90d17a57fe0ec834477cde779427866c5..4e08df4a5699bbc393eca9944f9405f56b68272f 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -141,17 +141,23 @@ const ( // ChannelService 渠道管理服务 type ChannelService struct { repo ChannelRepository + groupRepo GroupRepository authCacheInvalidator APIKeyAuthCacheInvalidator + pricingService *PricingService // 用于「可用渠道」展示时回落到全局定价;可为 nil(测试场景) cache atomic.Value // *channelCache cacheSF singleflight.Group } -// NewChannelService 创建渠道服务实例 -func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { +// NewChannelService 创建渠道服务实例。 +// pricingService 仅供 ListAvailable 在渠道未配置定价时回落到全局 LiteLLM 数据; +// 计费热路径走独立的 ModelPricingResolver,与此参数无关。可传 nil。 +func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, pricingService *PricingService) *ChannelService { s := &ChannelService{ repo: repo, + groupRepo: groupRepo, authCacheInvalidator: authCacheInvalidator, + pricingService: pricingService, } return s } @@ -299,6 +305,9 @@ func (s *ChannelService) fetchChannelData(ctx 
context.Context) ([]Channel, map[i } // populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。 +// 装填时对每个 Channel 统一归一化 BillingModelSource,让缓存命中的所有下游 +// (gateway routing / billing / 未来任何 cache-backed 读路径)都拿到已归一化的实体, +// 避免"每个出口各自记得 normalize"反模式。 func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache { cache := newEmptyChannelCache() cache.groupPlatform = groupPlatforms @@ -306,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) * cache.loadedAt = time.Now() for i := range channels { + channels[i].normalizeBillingModelSource() ch := &channels[i] cache.byID[ch.ID] = ch for _, gid := range ch.GroupIDs { @@ -516,14 +526,13 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g // resolveMapping 基于已查找的渠道信息解析模型映射。 // antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。 func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { + // lk.channel 来自已装填的缓存,BillingModelSource 已在 populateChannelCache 阶段归一化, + // 这里无需重复兜底。 result := ChannelMappingResult{ MappedModel: model, ChannelID: lk.channel.ID, BillingModelSource: lk.channel.BillingModelSource, } - if result.BillingModelSource == "" { - result.BillingModelSource = BillingModelSourceChannelMapped - } modelLower := strings.ToLower(model) if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" { @@ -684,9 +693,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ApplyPricingToAccountStats: input.ApplyPricingToAccountStats, AccountStatsPricingRules: input.AccountStatsPricingRules, } - if channel.BillingModelSource == "" { - channel.BillingModelSource = BillingModelSourceChannelMapped - } + channel.normalizeBillingModelSource() if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err @@ -702,12 +709,23 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) } 
s.invalidateCache() - return s.repo.GetByID(ctx, channel.ID) + created, err := s.repo.GetByID(ctx, channel.ID) + if err != nil { + return nil, err + } + created.normalizeBillingModelSource() + return created, nil } -// GetByID 获取渠道详情 +// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped, +// 让所有 handler 无需重复处理历史空值。 func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) { - return s.repo.GetByID(ctx, id) + ch, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + ch.normalizeBillingModelSource() + return ch, nil } // Update 更新渠道 @@ -739,7 +757,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan s.invalidateCache() s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) - return s.repo.GetByID(ctx, id) + updated, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + updated.normalizeBillingModelSource() + return updated, nil } // applyUpdateInput 将更新请求的字段应用到渠道实体上。 @@ -857,7 +880,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { // List 获取渠道列表 func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { - return s.repo.List(ctx, params, status, search) + channels, res, err := s.repo.List(ctx, params, status, search) + if err != nil { + return nil, nil, err + } + for i := range channels { + channels[i].normalizeBillingModelSource() + } + return channels, res, nil } // modelEntry 表示一个模型模式条目(用于冲突检测) @@ -884,12 +914,7 @@ func conflictsBetween(a, b modelEntry) bool { // toModelEntry 将模型名转换为 modelEntry func toModelEntry(pattern string) modelEntry { - lower := strings.ToLower(pattern) - isWild := strings.HasSuffix(lower, "*") - prefix := lower - if isWild { - prefix = strings.TrimSuffix(lower, "*") - } + prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern)) return modelEntry{pattern: pattern, prefix: prefix, wildcard: 
isWild} } diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index e1345618c909a302e6af08b34c3b6bf81f3ef947..e737a21125bec3cc1ae4de50d14978bf21173823 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context // --------------------------------------------------------------------------- func newTestChannelService(repo *mockChannelRepository) *ChannelService { - return NewChannelService(repo, nil) + return NewChannelService(repo, nil, nil, nil) } func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { - return NewChannelService(repo, auth) + return NewChannelService(repo, nil, auth, nil) } // makeStandardRepo returns a repo that serves one active channel with anthropic pricing diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index deac64d629faf870405079ab614be3a714589986..164861fb93d349f0ab17b5ce002f2dcd207b8e40 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -433,3 +433,296 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) { require.Contains(t, err.Error(), "unbounded") require.Contains(t, err.Error(), "last") } + +func TestSupportedModels_ExactKeysAndPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)}, + {ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-opus-4-6": "claude-opus-4-6", + }, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 2) + 
require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-opus-4-6", got[0].Name) + require.NotNil(t, got[0].Pricing) + require.Equal(t, int64(11), got[0].Pricing.ID) + require.Equal(t, "claude-sonnet-4-6", got[1].Name) + require.Equal(t, int64(10), got[1].Pricing.ID) +} + +func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}}, + {ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-*": "claude-sonnet-4-6", + }, + }, + } + + got := ch.SupportedModels() + names := make([]string, 0, len(got)) + for _, m := range got { + names = append(names, m.Name) + } + require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6", "claude-opus-4-6"}, names) + for _, m := range got { + require.NotContains(t, m.Name, "*") + } +} + + +func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) { + ch := &Channel{ + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"}, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) + require.Nil(t, got[0].Pricing) +} + +func TestSupportedModels_DedupAndSort(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}}, + {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4-6": "upstream-a", + "claude-sonnet-*": "upstream-a", + }, + "openai": {"gpt-4o": "gpt-4o"}, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 3) + require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-sonnet-4-5", got[0].Name) + 
require.Equal(t, "anthropic", got[1].Platform) + require.Equal(t, "claude-sonnet-4-6", got[1].Name) + require.Equal(t, "openai", got[2].Platform) + require.Equal(t, "gpt-4o", got[2].Name) +} + +func TestSupportedModels_NilChannelAndEmpty(t *testing.T) { + var nilCh *Channel + require.Nil(t, nilCh.SupportedModels()) + + empty := &Channel{} + require.Nil(t, empty.SupportedModels()) +} + +func TestGetModelPricingByPlatform(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)}, + {ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)}, + }, + } + + ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6") + require.NotNil(t, ant) + require.Equal(t, int64(1), ant.ID) + + oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6") + require.NotNil(t, oa) + require.Equal(t, int64(2), oa.ID) + + require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6")) +} + +func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) { + // 定价中含通配符条目(pattern),不应被当作具体模型名展开。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) + for _, m := range got { + require.NotContains(t, m.Name, "*") + } +} + +func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) { + // 通配符模式无任何对应定价模型时,该平台 mapping 路不产出; + // 但其他平台的 pricing-only 模型仍会通过 Pass B 出现。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"gpt-foo-*": "gpt-foo-1"}, + }, + } + 
got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "openai", got[0].Platform) + require.Equal(t, "gpt-4o", got[0].Name) +} + +func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) { + // anthropic 的通配符不应把 openai 定价行拉到 anthropic 平台下; + // openai 的 pricing-only 模型则正常通过 Pass B 暴露在 openai 平台下。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-*": "x"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "openai", got[0].Platform, "不能把 openai 定价标记为 anthropic 模型") + require.Equal(t, "claude-sonnet-4-6", got[0].Name) +} + +func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) { + // 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"GPT-4o"}}, + {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "openai": {"gpt-*": "x"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "GPT-4o", got[0].Name) +} + +func TestSupportedModels_EmptyPlatformMapping(t *testing.T) { + // ModelMapping 平台 key 存在但 value 为空 map:mapping 路跳过该平台, + // 但 pricing 路仍会把该平台的定价模型补齐(关键修复:azcc 这种"只配定价不配映射"渠道)。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) + require.NotNil(t, got[0].Pricing) +} + +func TestSupportedModels_ExactKeyUsesPricedCaseWhenAvailable(t *testing.T) { + // mapping key uses uppercase, pricing uses lowercase — pricing's case should win. 
+ ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "openai": {"GPT-4o": "gpt-4o"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "gpt-4o", got[0].Name) // pricing's case wins +} + +func TestSupportedModels_AsteriskOnlyMappingExpandsAllPriced(t *testing.T) { + // 映射 key 为单独的 "*":前缀为空 → 命中该平台所有定价模型(透传场景)。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"gpt-4o", "gpt-4o-mini"}}, + }, + ModelMapping: map[string]map[string]string{ + "openai": {"*": "gpt-4o"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 2) + names := []string{got[0].Name, got[1].Name} + require.ElementsMatch(t, []string{"gpt-4o", "gpt-4o-mini"}, names) +} + +func TestSupportedModels_PricingOnlyNoMapping(t *testing.T) { + // 渠道完全没配 mapping,只配了定价 —— 应该把所有定价模型作为支持模型返回。 + // 这是修复前的核心 bug 场景(前端显示"未配置模型")。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)}, + {ID: 2, Platform: "anthropic", Models: []string{"claude-haiku-4-5"}, InputPrice: testPtrFloat64(3e-7)}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 2) + require.Equal(t, "claude-haiku-4-5", got[0].Name) + require.NotNil(t, got[0].Pricing) + require.Equal(t, int64(2), got[0].Pricing.ID) + require.Equal(t, "claude-opus-4-6", got[1].Name) + require.Equal(t, int64(1), got[1].Pricing.ID) +} + +func TestSupportedModels_ExactMappingUsesTargetPricing(t *testing.T) { + // 精确 mapping `src → target`:定价应按 target 查(实际计费的是 target), + // 而不是按 src 自查。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"req-model"}, InputPrice: testPtrFloat64(3e-6)}, + {ID: 200, Platform: "anthropic", Models: []string{"served-model"}, InputPrice: testPtrFloat64(1.5e-5)}, + }, + 
ModelMapping: map[string]map[string]string{ + "anthropic": { + "req-model": "served-model", + }, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 2) + require.Equal(t, "req-model", got[0].Name) + require.NotNil(t, got[0].Pricing) + require.Equal(t, int64(200), got[0].Pricing.ID, "req-model 显示但定价是 served-model 的(mapping target)") + require.Equal(t, "served-model", got[1].Name) + require.Equal(t, int64(200), got[1].Pricing.ID) +} + +func TestSupportedModels_ExactMappingTargetMissingFromPricing(t *testing.T) { + // `src → target` 但 target 不在渠道定价里 —— 结果中 src 的 Pricing 为 nil + // (等待 ListAvailable 阶段的全局 LiteLLM 回落填充)。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"some-priced-model"}, InputPrice: testPtrFloat64(1.5e-5)}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "missing-src": "missing-target", + }, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 2) + require.Equal(t, "missing-src", got[0].Name) + require.Nil(t, got[0].Pricing, "target 在渠道定价中缺失时不虚假填充,留给 ListAvailable 走 LiteLLM 回落") + require.Equal(t, "some-priced-model", got[1].Name) + require.NotNil(t, got[1].Pricing) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cb452efbeec05e037e188c3eea0a0c0c55a7313e..cf47b76f4e62e24365fce01ed0b02e77808d007c 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" +// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" + // Setting keys const ( // 注册设置 @@ -108,6 +111,24 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" 
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // WeChat Connect OAuth 登录设置 + SettingKeyWeChatConnectEnabled = "wechat_connect_enabled" + SettingKeyWeChatConnectAppID = "wechat_connect_app_id" + SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret" + SettingKeyWeChatConnectOpenAppID = "wechat_connect_open_app_id" + SettingKeyWeChatConnectOpenAppSecret = "wechat_connect_open_app_secret" + SettingKeyWeChatConnectMPAppID = "wechat_connect_mp_app_id" + SettingKeyWeChatConnectMPAppSecret = "wechat_connect_mp_app_secret" + SettingKeyWeChatConnectMobileAppID = "wechat_connect_mobile_app_id" + SettingKeyWeChatConnectMobileAppSecret = "wechat_connect_mobile_app_secret" + SettingKeyWeChatConnectOpenEnabled = "wechat_connect_open_enabled" + SettingKeyWeChatConnectMPEnabled = "wechat_connect_mp_enabled" + SettingKeyWeChatConnectMobileEnabled = "wechat_connect_mobile_enabled" + SettingKeyWeChatConnectMode = "wechat_connect_mode" + SettingKeyWeChatConnectScopes = "wechat_connect_scopes" + SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url" + SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url" + // Generic OIDC OAuth 登录设置 SettingKeyOIDCConnectEnabled = "oidc_connect_enabled" SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name" @@ -149,9 +170,33 @@ const ( SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 - SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制) + + // 第三方认证来源默认授予配置 + SettingKeyAuthSourceDefaultEmailBalance = 
"auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) @@ -198,6 +243,23 @@ const ( // 
SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= + // Channel Monitor (渠道监控) + // ========================= + + // SettingKeyChannelMonitorEnabled is a DB-backed soft switch for the channel monitor feature. + // When false: runner skips scheduling and user-facing endpoints return an empty list. + SettingKeyChannelMonitorEnabled = "channel_monitor_enabled" + + // SettingKeyChannelMonitorDefaultIntervalSeconds controls the default interval (seconds) + // pre-filled when creating a new channel monitor from the admin UI. Range: [15, 3600]. + SettingKeyChannelMonitorDefaultIntervalSeconds = "channel_monitor_default_interval_seconds" + + // SettingKeyAvailableChannelsEnabled is a DB-backed soft switch for the "Available Channels" + // user-facing aggregate view. When false: user endpoint returns an empty list and the + // sidebar entry is hidden. Defaults to false (opt-in feature). 
+ SettingKeyAvailableChannelsEnabled = "available_channels_enabled" + // ========================= // Overload Cooldown (529) // ========================= diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 55cb2c84548bd3a26393ee2d8365cb737b14f905..498336a4005a5a82c78a1c4edc16f383c679941e 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string { return nil } switch value { - case "low", "medium", "high", "max": + case "low", "medium", "high", "xhigh", "max": return &value default: return nil diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index d262456dc65a40bd53f86d9df0851c04cff65e54..40bd11867280b6a52c30f1fed64d2ed0bb1e5395 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) { body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`, wantEffort: "max", }, + { + name: "output_config.effort xhigh", + body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`, + wantEffort: "xhigh", + }, { name: "output_config without effort", body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`, @@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) { {"LOW", strPtr("low")}, {"Max", strPtr("max")}, {" medium ", strPtr("medium")}, + {"xhigh", strPtr("xhigh")}, + {"XHIGH", strPtr("xhigh")}, {"", nil}, {"unknown", nil}, - {"xhigh", nil}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4b4fc0bf25ec8a2a4f4b2b1c12ce2bac6eabeabd..5a91d0dee358c6a2d4a438d748c423168258c5c5 100644 --- 
a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i } // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 -// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, -// 或请求的模型处于限流状态时,返回 true。 -// 这确保后续请求不会继续使用不可用的账号。 +// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等), +// 额外检查模型级限流。 // // shouldClearStickySession checks if an account is in an unschedulable state // and the sticky session binding should be cleared. -// Returns true when account status is error/disabled, schedulable is false, -// within temporary unschedulable period, or the requested model is rate-limited. -// This ensures subsequent requests won't continue using unavailable accounts. +// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting. func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { return false } - if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable { + if !account.IsSchedulable() { return true } - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - return true - } - // 检查模型限流和 scope 限流,有限流即清除粘性会话 if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { return true } @@ -7317,8 +7310,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill cost := p.Cost if p.IsSubscriptionBill { - if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + // Subscription usage tracked by ActualCost so group rate multiplier + // consumes the quota at the expected speed. 
+ if cost.ActualCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } } @@ -7417,9 +7412,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage } } + // Record subscription / balance cost using ActualCost so the group (and any + // user-specific) rate multiplier consumes subscription quota at the expected + // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards + // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0). if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { cmd.SubscriptionID = &p.Subscription.ID - cmd.SubscriptionCost = p.Cost.TotalCost + cmd.SubscriptionCost = p.Cost.ActualCost } else if p.Cost.ActualCost > 0 { cmd.BalanceCost = p.Cost.ActualCost } @@ -7478,8 +7477,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu } if p.IsSubscriptionBill { - if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost) } } else if p.Cost.ActualCost > 0 && p.User != nil { deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go new file mode 100644 index 0000000000000000000000000000000000000000..42a81035c0b10e5f940ab9e8620836d8481a12a1 --- /dev/null +++ b/backend/internal/service/gateway_service_subscription_billing_test.go @@ -0,0 +1,85 @@ +//go:build unit + +package service + 
+import ( + "testing" +) + +// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix +// that subscription-mode billing honours the group (and any user-specific) rate +// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost * +// RateMultiplier), not raw TotalCost. +func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) { + t.Parallel() + + groupID := int64(7) + subID := int64(42) + + tests := []struct { + name string + totalCost float64 + actualCost float64 + isSubscription bool + wantSub float64 + wantBalance float64 + }{ + { + name: "subscription with 2x multiplier consumes 2x quota", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: true, + wantSub: 2.0, + wantBalance: 0, + }, + { + name: "subscription with 0.5x multiplier consumes 0.5x quota", + totalCost: 1.0, + actualCost: 0.5, + isSubscription: true, + wantSub: 0.5, + wantBalance: 0, + }, + { + name: "free subscription (multiplier 0) consumes no quota", + totalCost: 1.0, + actualCost: 0, + isSubscription: true, + wantSub: 0, + wantBalance: 0, + }, + { + name: "balance billing keeps using ActualCost (regression)", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: false, + wantSub: 0, + wantBalance: 2.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := &postUsageBillingParams{ + Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost}, + User: &User{ID: 1}, + APIKey: &APIKey{ID: 2, GroupID: &groupID}, + Account: &Account{ID: 3}, + Subscription: &UserSubscription{ID: subID}, + IsSubscriptionBill: tt.isSubscription, + } + + cmd := buildUsageBillingCommand("req-1", nil, p) + if cmd == nil { + t.Fatal("buildUsageBillingCommand returned nil") + } + if cmd.SubscriptionCost != tt.wantSub { + t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub) + } + if cmd.BalanceCost != tt.wantBalance { + t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, 
tt.wantBalance) + } + }) + } +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 12262613571e347d00a9ed1f4dfbffbd6649c57a..bb4c5aa1ba1b4c2db1a8c899b19ae85f21994cfe 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -59,6 +59,10 @@ type Group struct { DefaultMappedModel string MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组级每分钟请求数上限(0 = 不限制)。 + // 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。 + RPMLimit int + CreatedAt time.Time UpdatedAt time.Time @@ -76,10 +80,6 @@ func (g *Group) IsSubscriptionType() bool { return g.SubscriptionType == SubscriptionTypeSubscription } -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - func (g *Group) HasDailyLimit() bool { return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 } diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go index b7ca4cb76e98af77231252e31f2ccd17c7afbeb2..580897767dd8e182d546013007140890de6bb20c 100644 --- a/backend/internal/service/model_pricing_resolver.go +++ b/backend/internal/service/model_pricing_resolver.go @@ -61,6 +61,25 @@ type PricingInput struct { // 1. 获取基础定价(LiteLLM → Fallback) // 2. 如果指定了 GroupID,查找渠道定价并覆盖 func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing { + var chPricing *ChannelModelPricing + if input.GroupID != nil && r.channelService != nil { + chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model) + if chPricing != nil { + mode := chPricing.BillingMode + if mode == "" { + mode = BillingModeToken + } + if mode == BillingModePerRequest || mode == BillingModeImage { + resolved := &ResolvedPricing{ + Mode: mode, + Source: PricingSourceChannel, + } + r.applyRequestTierOverrides(chPricing, resolved) + return resolved + } + } + } + // 1. 
获取基础定价 basePricing, source := r.resolveBasePricing(input.Model) @@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) } // 2. 如果有 GroupID,尝试渠道覆盖 - if input.GroupID != nil { + if chPricing != nil { + resolved.Source = PricingSourceChannel + r.applyTokenOverrides(chPricing, resolved) + } else if input.GroupID != nil { r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved) } diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go index 905c4df68523df81e04e5019278d96decd14099d..4548c1d598e34c4813a09d15cd740b466d055bcb 100644 --- a/backend/internal/service/model_pricing_resolver_test.go +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP return map[int64]string{groupID: "anthropic"}, nil }, } - cs := NewChannelService(repo, nil) + cs := NewChannelService(repo, nil, nil, nil) bs := newTestBillingServiceForResolver() return NewModelPricingResolver(cs, bs) } @@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) { return nil, errors.New("database unavailable") }, } - cs := NewChannelService(repo, nil) + cs := NewChannelService(repo, nil, nil, nil) bs := newTestBillingServiceForResolver() r := NewModelPricingResolver(cs, bs) diff --git a/backend/internal/service/openai_403_counter.go b/backend/internal/service/openai_403_counter.go new file mode 100644 index 0000000000000000000000000000000000000000..5ba3e195ec6d661aee8cd5c2e725ea7db597a502 --- /dev/null +++ b/backend/internal/service/openai_403_counter.go @@ -0,0 +1,11 @@ +package service + +import "context" + +// OpenAI403CounterCache 追踪 OpenAI 账号连续 403 失败次数。 +type OpenAI403CounterCache interface { + // IncrementOpenAI403Count 原子递增 403 计数并返回当前值。 + IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) + // 
ResetOpenAI403Count 成功后清零计数器。 + ResetOpenAI403Count(ctx context.Context, accountID int64) error +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354a1eb3b926dafa15ecd3cc1e4a37c11cc..808f1229ebea16fce0a4857b14d06ed4bbc8bb7a 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -13,22 +13,39 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/sync/singleflight" ) const ( openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerLoadBalance = "load_balance" + openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled" +) + +const ( + openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second + openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second ) +type cachedOpenAIAdvancedSchedulerSetting struct { + enabled bool + expiresAt int64 +} + +var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting +var openAIAdvancedSchedulerSettingSF singleflight.Group + type OpenAIAccountScheduleRequest struct { - GroupID *int64 - SessionHash string - StickyAccountID int64 - PreviousResponseID string - RequestedModel string - RequiredTransport OpenAIUpstreamTransport - ExcludedIDs map[int64]struct{} + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + RequiredImageCapability OpenAIImagesCapability + ExcludedIDs map[int64]struct{} } type OpenAIAccountScheduleDecision struct { @@ -324,7 +341,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + if !s.isAccountRequestCompatible(account, req) { 
return nil, nil } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -600,7 +617,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue } - if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + if !s.isAccountRequestCompatible(account, req) { continue } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -706,11 +723,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) @@ -733,7 +750,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } return &AccountSelectionResult{ @@ -751,14 +768,23 @@ func (s *defaultOpenAIAccountScheduler) 
selectByLoadBalance( } func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { - // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { return true } - if s == nil || s.service == nil || account == nil { + if s == nil || s.service == nil { return false } - return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport + return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) +} + +func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool { + if account == nil { + return false + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return false + } + return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) } func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { @@ -805,10 +831,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler return snapshot } -func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { +func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository { + if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil { + return nil + } + return s.rateLimitService.settingService.settingRepo +} + +func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled + } + } + + result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) { + if cached, ok := 
openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled, nil + } + } + + enabled := false + if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil { + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout) + defer cancel() + + value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey) + if err == nil { + enabled = strings.EqualFold(strings.TrimSpace(value), "true") + } + } + + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: enabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + return enabled, nil + }) + + enabled, _ := result.(bool) + return enabled +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler { if s == nil { return nil } + if !s.isOpenAIAdvancedSchedulerEnabled(ctx) { + return nil + } s.openaiSchedulerOnce.Do(func() { if s.openaiAccountStats == nil { s.openaiAccountStats = newOpenAIAccountRuntimeStats() @@ -820,6 +892,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule return s.openaiScheduler } +func resetOpenAIAdvancedSchedulerSettingCacheForTest() { + openAIAdvancedSchedulerSettingCache = atomic.Value{} + openAIAdvancedSchedulerSettingSF = singleflight.Group{} +} + func (s *OpenAIGatewayService) SelectAccountWithScheduler( ctx context.Context, groupID *int64, @@ -828,13 +905,92 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requestedModel string, excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "") +} + +func (s *OpenAIGatewayService) 
SelectAccountWithSchedulerForImages( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredCapability OpenAIImagesCapability, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + if err == nil && selection != nil && selection.Account != nil { + return selection, decision, nil + } + // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) + if requiredCapability == OpenAIImagesCapabilityNative { + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic) + } + return selection, decision, err +} + +func (s *OpenAIGatewayService) selectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, + requiredImageCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance - return selection, decision, err + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) + for { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + if err != nil { + return nil, decision, err + } + if selection == nil || selection.Account == 
nil { + return selection, decision, nil + } + if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) { + return selection, decision, nil + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if effectiveExcludedIDs == nil { + effectiveExcludedIDs = make(map[int64]struct{}) + } + if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists { + return nil, decision, ErrNoAvailableAccounts + } + effectiveExcludedIDs[selection.Account.ID] = struct{}{} + } + } + + effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) + for { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + if err != nil { + return nil, decision, err + } + if selection == nil || selection.Account == nil { + return selection, decision, nil + } + if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) { + return selection, decision, nil + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if effectiveExcludedIDs == nil { + effectiveExcludedIDs = make(map[int64]struct{}) + } + if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists { + return nil, decision, ErrNoAvailableAccounts + } + effectiveExcludedIDs[selection.Account.ID] = struct{}{} + } } var stickyAccountID int64 @@ -845,18 +1001,40 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( } return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ - GroupID: groupID, - SessionHash: sessionHash, - StickyAccountID: stickyAccountID, - PreviousResponseID: previousResponseID, - RequestedModel: requestedModel, - RequiredTransport: requiredTransport, - ExcludedIDs: excludedIDs, + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + RequiredImageCapability: requiredImageCapability, + ExcludedIDs: excludedIDs, }) } +func 
cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} { + if len(excludedIDs) == 0 { + return nil + } + cloned := make(map[int64]struct{}, len(excludedIDs)) + for id := range excludedIDs { + cloned[id] = struct{}{} + } + return cloned +} + +func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || account == nil { + return false + } + return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -864,7 +1042,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -872,7 +1050,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { } func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return OpenAIAccountSchedulerMetricsSnapshot{} } diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed40ae239a0dc4c4dc0dd56ad6f61ce11d..b02370cb5ffd28c9caf9c57bf4c66e1559b4518a 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -2,6 +2,7 @@ package service 
import ( "context" + "errors" "fmt" "math" "sync" @@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { accountsByID map[int64]*Account } +type schedulerTestOpenAIAccountRepo struct { + AccountRepository + accounts []Account +} + +func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + +type schedulerTestConcurrencyCache struct { + ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool +} + +func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } + return true, nil +} + +func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} + +func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts 
[]AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } + out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } + for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } + out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return out, nil +} + +func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type schedulerTestGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + +func 
newSchedulerTestOpenAIWSV2Config() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} + +type openAIAdvancedSchedulerSettingRepoStub struct { + values map[string]string +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s == nil || s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected call to Set") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected call to GetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected call to SetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected call to GetAll") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected call to Delete") +} + +func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + repo := &openAIAdvancedSchedulerSettingRepoStub{ + values: map[string]string{}, + } + if enabled != "" { + repo.values[openAIAdvancedSchedulerSettingKey] = enabled + } + return &RateLimitService{ + 
settingService: NewSettingService(repo, &config.Config{}), + } +} + func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { if len(s.snapshotAccounts) == 0 { return nil, false, nil @@ -45,6 +242,230 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 return &cloned, nil } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10106) + accounts := []Account{ + { + ID: 36001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + }, + { + ID: 36002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour)) + require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_disabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) +} + +func 
TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10108) + accounts := []Account{ + { + ID: 36011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 36012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36012), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10109) + accounts := []Account{ + { + ID: 36021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: 
&schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.ErrorContains(t, err, "no available OpenAI accounts") + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10107) + accounts := []Account{ + { + ID: 37001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour)) + require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + 
ctx, + &groupID, + "resp_enabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { ctx := context.Background() groupID := int64(10101) @@ -53,10 +474,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotService := 
&SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, + cache: cache, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) @@ -76,7 +504,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + } account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) require.NoError(t, err) @@ -92,18 +525,19 @@ func 
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} snapshotCache := &openAISnapshotCacheStub{ snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, cache: cache, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -128,8 +562,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + accountRepo: 
schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, } @@ -153,7 +588,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "openai_apikey_responses_websockets_v2_enabled": true, }, } - cache := &stubGatewayCache{} + cache := &schedulerTestGatewayCache{} cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true @@ -163,10 +598,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: cfg, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } store := svc.getOpenAIWSStateStore() @@ -204,17 +640,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_abc": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -260,7 +697,7 @@ func 
TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS Priority: 9, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_sticky_busy": 21001, }, @@ -273,7 +710,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ acquireResults: map[int64]bool{ 21001: false, // sticky 账号已满 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) @@ -288,9 +725,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -328,17 +766,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "openai_ws_force_http": true, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_force_http": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -387,15 +826,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick }, }, } - cache := &stubGatewayCache{ + cache := 
&schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_ws_only": 2201, }, } - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, @@ -403,9 +842,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -445,10 +885,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, - cfg: newOpenAIWSV2TestConfig(), - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: newSchedulerTestOpenAIWSV2Config(), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -507,7 +948,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3002: 
{AccountID: 3002, LoadRate: 20, WaitingCount: 1}, @@ -520,9 +961,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -559,16 +1001,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_metrics": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -749,7 +1192,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, @@ -757,9 +1200,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA }, } svc := &OpenAIGatewayService{ - accountRepo: 
stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -905,12 +1349,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { } func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + svc := &OpenAIGatewayService{} ttft := 120 svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() - require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) @@ -947,7 +1393,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() scheduler.service = &OpenAIGatewayService{cfg: cfg} account := &Account{ ID: 8801, diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index c5de8203412ead756d693f12819f3c0d53b4a9ed..ddafc6eb76d5dbb9afcc9bf8ad2c74b3d474211a 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -38,11 +38,12 @@ func 
TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a266d6a01c0ccde93e95a32c7b27500f666fdee4..14abde9b3c6838023f9a8040ef3ae67cf522488a 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -6,9 +6,9 @@ import ( ) var codexModelMap = map[string]string{ + "gpt-5.5": "gpt-5.5", "gpt-5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", "gpt-5.4-none": "gpt-5.4", "gpt-5.4-low": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4", @@ -22,52 +22,21 @@ var codexModelMap = map[string]string{ "gpt-5.3-high": "gpt-5.3-codex", "gpt-5.3-xhigh": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-low": "gpt-5.3-codex", - "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", 
"gpt-5.3-codex-low": "gpt-5.3-codex", "gpt-5.3-codex-medium": "gpt-5.3-codex", "gpt-5.3-codex-high": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": "gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", "gpt-5.2": "gpt-5.2", "gpt-5.2-none": "gpt-5.2", "gpt-5.2-low": "gpt-5.2", "gpt-5.2-medium": "gpt-5.2", "gpt-5.2-high": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -76,6 +45,11 @@ type codexTransformResult struct { PromptCacheKey string } +const ( + codexImageGenerationBridgeMarker = "" + codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. 
The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n" +) + func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -219,8 +193,12 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } func normalizeCodexModel(model string) string { + model = strings.TrimSpace(model) if model == "" { - return "gpt-5.1" + return "gpt-5.4" + } + if isOpenAIImageGenerationModel(model) { + return model } modelID := model @@ -235,52 +213,238 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") { + return "gpt-5.5" + } if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { return "gpt-5.4-mini" } - if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") { - return "gpt-5.4-nano" - } if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return "gpt-5.4" } - if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { - return "gpt-5.2-codex" - } if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { return "gpt-5.2" } + if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") { + return "gpt-5.3-codex-spark" + } if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { return "gpt-5.3-codex" } - if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, 
"gpt 5.1 codex max") { - return "gpt-5.1-codex-max" + if strings.Contains(normalized, "codex") { + return "gpt-5.3-codex" } - if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") { - return "gpt-5.1-codex-mini" + if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { + return "gpt-5.4" } - if strings.Contains(normalized, "codex-mini-latest") || - strings.Contains(normalized, "gpt-5-codex-mini") || - strings.Contains(normalized, "gpt 5 codex mini") { - return "codex-mini-latest" + + return "gpt-5.4" +} + +func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false } - if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") { - return "gpt-5.1-codex" + tools, ok := rawTools.([]any) + if !ok { + return false } - if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") { - return "gpt-5.1" + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + return true + } } - if strings.Contains(normalized, "codex") { - return "gpt-5.1-codex" + return false +} + +func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false } - if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { - return "gpt-5.1" + tools, ok := rawTools.([]any) + if !ok { + return false + } + + modified := false + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" { + continue + } + if _, ok := toolMap["output_format"]; !ok { + if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != 
"" { + toolMap["output_format"] = value + modified = true + } + } + if _, ok := toolMap["output_compression"]; !ok { + if value, exists := toolMap["compression"]; exists && value != nil { + toolMap["output_compression"] = value + modified = true + } + } + if _, ok := toolMap["format"]; ok { + delete(toolMap, "format") + modified = true + } + if _, ok := toolMap["compression"]; ok { + delete(toolMap, "compression") + modified = true + } + } + return modified +} + +func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + + tool := map[string]any{ + "type": "image_generation", + "output_format": "png", + } + + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + reqBody["tools"] = []any{tool} + return true + } + + tools, ok := rawTools.([]any) + if !ok { + reqBody["tools"] = []any{tool} + return true + } + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + return false + } + } + + reqBody["tools"] = append(tools, tool) + return true +} + +func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { + if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) { + return false } - return "gpt-5.1" + existing, _ := reqBody["instructions"].(string) + if strings.Contains(existing, codexImageGenerationBridgeMarker) { + return false + } + + existing = strings.TrimRight(existing, " \t\r\n") + if strings.TrimSpace(existing) == "" { + reqBody["instructions"] = codexImageGenerationBridgeText + return true + } + + reqBody["instructions"] = existing + "\n\n" + codexImageGenerationBridgeText + return true +} + +func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { + if !hasOpenAIImageGenerationTool(reqBody) { + return nil + } + model = strings.TrimSpace(model) + if !isOpenAIImageGenerationModel(model) { + return nil + 
} + return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model) +} + +func normalizeOpenAIResponsesImageOnlyModel(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + imageModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"])) + if !isOpenAIImageGenerationModel(imageModel) { + return false + } + + modified := false + tools, _ := reqBody["tools"].([]any) + imageToolIndex := -1 + for i, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + imageToolIndex = i + break + } + } + if imageToolIndex < 0 { + tools = append(tools, map[string]any{ + "type": "image_generation", + "model": imageModel, + }) + imageToolIndex = len(tools) - 1 + reqBody["tools"] = tools + modified = true + } + + if toolMap, ok := tools[imageToolIndex].(map[string]any); ok { + if strings.TrimSpace(firstNonEmptyString(toolMap["model"])) == "" { + toolMap["model"] = imageModel + modified = true + } + for _, key := range []string{ + "size", + "quality", + "background", + "output_format", + "output_compression", + "moderation", + "style", + "partial_images", + } { + if value, exists := reqBody[key]; exists && value != nil { + if _, toolHas := toolMap[key]; !toolHas { + toolMap[key] = value + } + delete(reqBody, key) + modified = true + } + } + } + + if prompt := strings.TrimSpace(firstNonEmptyString(reqBody["prompt"])); prompt != "" { + if _, hasInput := reqBody["input"]; !hasInput { + reqBody["input"] = prompt + } + delete(reqBody, "prompt") + modified = true + } + + if _, ok := reqBody["tool_choice"]; !ok { + reqBody["tool_choice"] = map[string]any{"type": "image_generation"} + modified = true + } + if imageModel != openAIImagesResponsesMainModel { + modified = true + } + reqBody["model"] = openAIImagesResponsesMainModel + return modified } func 
normalizeOpenAIModelForUpstream(account *Account, model string) string { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 993ade0747f15d1239bd8a77497a9bc033691dfc..4fd16fdb4af14ee0ac407e5fcceb4a8627ec393d 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -217,6 +217,195 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction require.Equal(t, "bash", first["name"]) } +func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) { + reqBody := map[string]any{ + "tools": []any{ + map[string]any{ + "type": "image_generation", + "format": "png", + "compression": 60, + }, + }, + } + + modified := normalizeOpenAIResponsesImageGenerationTools(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "png", first["output_format"]) + require.Equal(t, 60, first["output_compression"]) + _, hasFormat := first["format"] + require.False(t, hasFormat) + _, hasCompression := first["compression"] + require.False(t, hasCompression) +} + +func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": "draw a cat", + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", tool["type"]) + require.Equal(t, "png", tool["output_format"]) +} + +func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "web_search"}, + }, + } + 
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 2) + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "web_search", first["type"]) + second, ok := tools[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", second["type"]) + require.Equal(t, "png", second["output_format"]) +} + +func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "webp"}, + map[string]any{"type": "web_search"}, + }, + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.False(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 2) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "webp", tool["output_format"]) +} + +func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) { + reqBody := map[string]any{ + "instructions": "existing instructions", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "png"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.True(t, modified) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Contains(t, instructions, "existing instructions") + require.Contains(t, instructions, codexImageGenerationBridgeMarker) + require.Contains(t, instructions, "Responses native `image_generation` tool") + + modified = applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) +} + +func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) { + reqBody := map[string]any{ + "instructions": "existing instructions", + "tools": []any{ + 
map[string]any{"type": "web_search"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) + require.Equal(t, "existing instructions", reqBody["instructions"]) +} + +func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-image-2", + "prompt": "draw a cat", + "size": "1024x1024", + "output_format": "png", + } + + modified := normalizeOpenAIResponsesImageOnlyModel(reqBody) + require.True(t, modified) + require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"]) + require.Equal(t, "draw a cat", reqBody["input"]) + _, hasPrompt := reqBody["prompt"] + require.False(t, hasPrompt) + _, hasTopLevelSize := reqBody["size"] + require.False(t, hasTopLevelSize) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", tool["type"]) + require.Equal(t, "gpt-image-2", tool["model"]) + require.Equal(t, "1024x1024", tool["size"]) + require.Equal(t, "png", tool["output_format"]) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", choice["type"]) +} + +func TestNormalizeOpenAIResponsesImageOnlyModel_PreservesExistingImageTool(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-image-2", + "input": "draw a cat", + "tools": []any{ + map[string]any{ + "type": "image_generation", + "model": "gpt-image-1.5", + }, + }, + "tool_choice": "auto", + } + + modified := normalizeOpenAIResponsesImageOnlyModel(reqBody) + require.True(t, modified) + require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"]) + require.Equal(t, "auto", reqBody["tool_choice"]) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "gpt-image-1.5", 
tool["model"]) +} + +func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) { + err := validateOpenAIResponsesImageModel(map[string]any{ + "tools": []any{ + map[string]any{"type": "image_generation"}, + }, + }, "gpt-image-2") + + require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`) +} + func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 @@ -240,15 +429,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt 5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", - "gpt 5.4 nano": "gpt-5.4-nano", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt 5.3 codex spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt 5.3 codex": "gpt-5.3-codex", } @@ -257,6 +444,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) { + cases := map[string]string{ + "": "gpt-5.4", + "gpt-5": "gpt-5.4", + "gpt-5-mini": "gpt-5.4", + "gpt-5-nano": "gpt-5.4", + "gpt-5.1": "gpt-5.4", + "gpt-5.1-codex": "gpt-5.3-codex", + "gpt-5.1-codex-max": "gpt-5.3-codex", + "gpt-5.1-codex-mini": "gpt-5.3-codex", + "gpt-5.2-codex": "gpt-5.2", + "codex-mini-latest": "gpt-5.3-codex", + "gpt-5-codex": "gpt-5.3-codex", + } + + for input, expected := range cases { + require.Equal(t, expected, normalizeCodexModel(input)) + } +} + func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.3-codex-spark", diff --git 
a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 88e16a4db0f9d5ec57e815532c7a6fa11be71d7e..fcd27f1921b9af90d40d0804f432945b63f20ee4 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,14 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch normalizeCodexModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex": + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel + // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + switch normalizeCodexModel(trimmed) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return true default: return false diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c6805464955ddc90cf120104aae9259b90bd32e9 --- /dev/null +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type openAI403CounterResetStub struct { + resetCalls []int64 +} + +func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) { + return 0, nil +} + +func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { + counter := &openAI403CounterResetStub{} + rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) + 
rateLimitSvc.SetOpenAI403CounterCache(counter) + + svc := &OpenAIGatewayService{ + rateLimitService: rateLimitSvc, + } + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{}, + Account: &Account{ID: 777, Platform: PlatformOpenAI}, + }) + + require.NoError(t, err) + require.Equal(t, []int64{777}, counter.resetCalls) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index ac7d28a7f0ddbd33642bdbac610fc00a03249337..663066a35d9ce7381b841180a3408fef390c013b 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( responsesBody = stripped } } + responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody) + if err != nil { + return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err) + } // Minimal stub populated from the raw body so downstream billing // propagation (ServiceTier, ReasoningEffort) keeps working. 
responsesReq = &apicompat.ResponsesRequest{ Model: upstreamModel, - ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(), + ServiceTier: normalizedServiceTier, } if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" { responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort} @@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("convert chat completions to responses: %w", err) } responsesReq.Model = upstreamModel + normalizeResponsesRequestServiceTier(responsesReq) responsesBody, err = json.Marshal(responsesReq) if err != nil { return nil, fmt.Errorf("marshal responses request: %w", err) @@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return result, handleErr } +func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) { + if req == nil { + return + } + req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier) +} + +func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) { + if len(body) == 0 { + return body, "", nil + } + rawServiceTier := gjson.GetBytes(body, "service_tier").String() + if rawServiceTier == "" { + return body, "", nil + } + normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier) + if normalizedServiceTier == "" { + trimmed, err := sjson.DeleteBytes(body, "service_tier") + return trimmed, "", err + } + if normalizedServiceTier == rawServiceTier { + return body, normalizedServiceTier, nil + } + trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier) + return trimmed, normalizedServiceTier, err +} + +func normalizedOpenAIServiceTierValue(raw string) string { + normalized := normalizeOpenAIServiceTier(raw) + if normalized == nil { + return "" + } + return *normalized +} + // handleChatCompletionsErrorResponse reads an upstream error and returns it in // OpenAI Chat Completions error format. 
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a00fb71cab773af64459759cf246cd0a7ce841b5 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -0,0 +1,44 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNormalizeResponsesRequestServiceTier(t *testing.T) { + t.Parallel() + + req := &apicompat.ResponsesRequest{ServiceTier: " fast "} + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "priority", req.ServiceTier) + + req.ServiceTier = "flex" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "flex", req.ServiceTier) + + req.ServiceTier = "default" + normalizeResponsesRequestServiceTier(req) + require.Empty(t, req.ServiceTier) +} + +func TestNormalizeResponsesBodyServiceTier(t *testing.T) { + t.Parallel() + + body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`)) + require.NoError(t, err) + require.Equal(t, "priority", tier) + require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`)) + require.NoError(t, err) + require.Equal(t, "flex", tier) + require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`)) + require.NoError(t, err) + require.Empty(t, tier) + require.False(t, gjson.GetBytes(body, "service_tier").Exists()) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go 
index e6fa94aa3b2c61754daadf84c5bad722bbcae430..9665c4c8b30b8836710a17e069f5b5d6bd881f73 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel Model: "gpt-5.1", Duration: time.Second, }, - APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}}, User: &User{ID: 200}, Account: &Account{ID: 300}, Subscription: subscription, @@ -1070,3 +1070,78 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, subRepo.incrementCalls) } + +func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_only_usage", + Model: "gpt-image-2", + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1007}, + User: &User{ID: 2007}, + Account: &Account{ID: 3007}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.ImageSize) + require.Equal(t, "1K", *usageRepo.lastLog.ImageSize) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t 
*testing.T) { + imagePrice := 0.02 + groupID := int64(12) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_per_request", + Model: "gpt-image-2", + Usage: OpenAIUsage{ + InputTokens: 1110, + OutputTokens: 1756, + ImageOutputTokens: 1756, + }, + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 2008}, + Account: &Account{ID: 3008}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 064191bd26aa454f0a33b05258bedd6e800f783b..d99cd7dabff088ce3d27fda9cf4e6c4555cec3da 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -233,6 +233,8 @@ type OpenAIForwardResult struct { ResponseHeaders http.Header Duration time.Duration FirstTokenMs *int + ImageCount int + ImageSize string } type OpenAIWSRetryMetricsSnapshot struct { @@ -1933,6 +1935,23 @@ func (s *OpenAIGatewayService) 
Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } + if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") + } + + if normalizeOpenAIResponsesImageGenerationTools(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") + } + if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") + } + // 对所有请求执行模型映射(包含 Codex CLI)。 billingModel := account.GetMappedModel(reqModel) if billingModel != reqModel { @@ -1942,6 +1961,40 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel + if normalizeOpenAIResponsesImageOnlyModel(reqBody) { + bodyModified = true + disablePatch() + if model, ok := reqBody["model"].(string); ok { + upstreamModel = strings.TrimSpace(model) + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Normalized /responses image-only model request inbound_model=%s image_model=%s upstream_model=%s", + reqModel, + billingModel, + upstreamModel, + ) + } + if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "model", + }, + }) + return nil, err + } + if hasOpenAIImageGenerationTool(reqBody) { + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s", 
+ reqModel, + upstreamModel, + account.Type, + ) + } // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, @@ -3889,6 +3942,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) + usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int()) } func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { @@ -3900,11 +3954,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { "usage.input_tokens", "usage.output_tokens", "usage.input_tokens_details.cached_tokens", + "usage.output_tokens_details.image_tokens", ) return OpenAIUsage{ InputTokens: int(values[0].Int()), OutputTokens: int(values[1].Int()), CacheReadInputTokens: int(values[2].Int()), + ImageOutputTokens: int(values[3].Int()), }, true } @@ -4087,22 +4143,39 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { // Returns (nil, false) if no content was found in deltas. 
func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { acc := apicompat.NewBufferedResponseAccumulator() + imageOutputs := make([]json.RawMessage, 0, 1) + seenImages := make(map[string]struct{}) lines := strings.Split(bodyText, "\n") for _, line := range lines { data, ok := extractOpenAISSEDataLine(line) if !ok || data == "" || data == "[DONE]" { continue } + if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok { + imageOutputs = append(imageOutputs, imageOutput) + } var event apicompat.ResponsesStreamEvent if err := json.Unmarshal([]byte(data), &event); err != nil { continue } acc.ProcessEvent(&event) } - if !acc.HasContent() { + if !acc.HasContent() && len(imageOutputs) == 0 { return nil, false } - output := acc.BuildOutput() + + var output []json.RawMessage + if acc.HasContent() { + outputJSON, err := json.Marshal(acc.BuildOutput()) + if err == nil { + _ = json.Unmarshal(outputJSON, &output) + } + } + output = append(output, imageOutputs...) 
+ if len(output) == 0 { + return nil, false + } + outputJSON, err := json.Marshal(output) if err != nil { return nil, false @@ -4110,6 +4183,33 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { return outputJSON, true } +func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct{}) (json.RawMessage, bool) { + if len(data) == 0 || !gjson.ValidBytes(data) { + return nil, false + } + if gjson.GetBytes(data, "type").String() != "response.output_item.done" { + return nil, false + } + item := gjson.GetBytes(data, "item") + if !item.Exists() || !item.IsObject() || item.Get("type").String() != "image_generation_call" { + return nil, false + } + if strings.TrimSpace(item.Get("result").String()) == "" { + return nil, false + } + key := strings.TrimSpace(item.Get("id").String()) + if key == "" { + key = strings.TrimSpace(item.Get("output_format").String()) + "|" + strings.TrimSpace(item.Get("result").String()) + } + if key != "" && seen != nil { + if _, exists := seen[key]; exists { + return nil, false + } + seen[key] = struct{}{} + } + return json.RawMessage(item.Raw), true +} + func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") @@ -4394,10 +4494,14 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result + if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { + s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) + } // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && - result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 { + result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && + 
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { return nil } @@ -4451,21 +4555,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - ServiceTier: serviceTier, - Resolver: s.resolver, - }) - } else { - cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) - } + cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -4505,6 +4595,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens, + ImageCount: result.ImageCount, + ImageSize: optionalTrimmedStringPtr(result.ImageSize), } if cost != nil { usageLog.InputCost = cost.InputCost @@ -4530,6 +4622,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if cost != nil && cost.BillingMode != "" { billingMode := cost.BillingMode usageLog.BillingMode = &billingMode + } else if result.ImageCount > 0 { + billingMode := string(BillingModeImage) + usageLog.BillingMode = &billingMode } else { billingMode := string(BillingModeToken) usageLog.BillingMode = &billingMode @@ -4589,6 +4684,83 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec return nil } +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( + ctx context.Context, + result *OpenAIForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + 
tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + if result != nil && result.ImageCount > 0 { + return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil + } + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + return s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } + return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) +} + +func (s *OpenAIGatewayService) calculateOpenAIImageCost( + ctx context.Context, + billingModel string, + apiKey *APIKey, + result *OpenAIForwardResult, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && + (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + RequestCount: 1, + SizeTier: result.ImageSize, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err == nil { + return cost + } + logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err) + } + + var groupConfig *ImagePriceConfig + if apiKey != nil && apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey == nil || apiKey.Group == nil { + return nil + } + gid := 
apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. // Exported for use in ratelimit_service when handling OpenAI 429 responses. func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index cf2d875fcd8d015d285253b2fb276feadf984915..ed7c78a3f73c948422a149a869f31727570284f8 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -18,6 +18,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) // 编译期接口断言 @@ -1880,6 +1881,33 @@ func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { require.NotContains(t, rec.Body.String(), "data:") } +func TestHandleSSEToJSON_ReconstructsImageGenerationOutputItemDone(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","result":"aGVsbG8=","revised_prompt":"draw a cat","output_format":"png"}}`, + `data: {"type":"response.completed","response":{"id":"resp_img","model":"gpt-5.4","output":[],"usage":{"input_tokens":7,"output_tokens":9,"output_tokens_details":{"image_tokens":4}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleSSEToJSON(resp, c, body, 
"gpt-5.4", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 4, usage.ImageOutputTokens) + require.NotContains(t, rec.Body.String(), "data:") + require.Equal(t, "image_generation_call", gjson.Get(rec.Body.String(), "output.0.type").String()) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "output.0.result").String()) + require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "output.0.revised_prompt").String()) +} + func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go new file mode 100644 index 0000000000000000000000000000000000000000..4badcb1c2cbaa55b917104ec147a86eae4ce0435 --- /dev/null +++ b/backend/internal/service/openai_images.go @@ -0,0 +1,1346 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + openAIImagesGenerationsEndpoint = "/v1/images/generations" + openAIImagesEditsEndpoint = "/v1/images/edits" + + openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations" + openAIImagesEditsURL = "https://api.openai.com/v1/images/edits" + + openAIChatGPTStartURL = "https://chatgpt.com/" + openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files" + openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download + openAIImageMaxUploadPartSize = 20 << 20 
// 20MB per multipart upload part + openAIImagesResponsesMainModel = "gpt-5.4-mini" +) + +type OpenAIImagesCapability string + +const ( + OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic" + OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native" +) + +type OpenAIImagesUpload struct { + FieldName string + FileName string + ContentType string + Data []byte + Width int + Height int +} + +type OpenAIImagesRequest struct { + Endpoint string + ContentType string + Multipart bool + Model string + ExplicitModel bool + Prompt string + Stream bool + N int + Size string + ExplicitSize bool + SizeTier string + ResponseFormat string + Quality string + Background string + OutputFormat string + Moderation string + InputFidelity string + Style string + OutputCompression *int + PartialImages *int + HasMask bool + HasNativeOptions bool + RequiredCapability OpenAIImagesCapability + InputImageURLs []string + MaskImageURL string + Uploads []OpenAIImagesUpload + MaskUpload *OpenAIImagesUpload + Body []byte + bodyHash string +} + +func (r *OpenAIImagesRequest) IsEdits() bool { + return r != nil && r.Endpoint == openAIImagesEditsEndpoint +} + +func (r *OpenAIImagesRequest) StickySessionSeed() string { + if r == nil { + return "" + } + parts := []string{ + "openai-images", + strings.TrimSpace(r.Endpoint), + strings.TrimSpace(r.Model), + strings.TrimSpace(r.Size), + strings.TrimSpace(r.Prompt), + } + seed := strings.Join(parts, "|") + if strings.TrimSpace(r.Prompt) == "" && r.bodyHash != "" { + seed += "|body=" + r.bodyHash + } + return seed +} + +func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) { + if c == nil || c.Request == nil { + return nil, fmt.Errorf("missing request context") + } + endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path) + if endpoint == "" { + return nil, fmt.Errorf("unsupported images endpoint") + } + + contentType := 
strings.TrimSpace(c.GetHeader("Content-Type")) + req := &OpenAIImagesRequest{ + Endpoint: endpoint, + ContentType: contentType, + N: 1, + Body: body, + } + if len(body) > 0 { + sum := sha256.Sum256(body) + req.bodyHash = hex.EncodeToString(sum[:8]) + } + + mediaType, _, err := mime.ParseMediaType(contentType) + if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { + req.Multipart = true + if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil { + return nil, parseErr + } + } else { + if len(body) == 0 { + return nil, fmt.Errorf("request body is empty") + } + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("failed to parse request body") + } + if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil { + return nil, parseErr + } + } + + applyOpenAIImagesDefaults(req) + if err := validateOpenAIImagesModel(req.Model); err != nil { + return nil, err + } + req.SizeTier = normalizeOpenAIImageSizeTier(req.Size) + req.RequiredCapability = classifyOpenAIImagesCapability(req) + return req, nil +} + +func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { + if modelResult := gjson.GetBytes(body, "model"); modelResult.Exists() { + req.Model = strings.TrimSpace(modelResult.String()) + req.ExplicitModel = req.Model != "" + } + req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String()) + + if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() { + if streamResult.Type != gjson.True && streamResult.Type != gjson.False { + return fmt.Errorf("invalid stream field type") + } + req.Stream = streamResult.Bool() + } + + if nResult := gjson.GetBytes(body, "n"); nResult.Exists() { + if nResult.Type != gjson.Number { + return fmt.Errorf("invalid n field type") + } + req.N = int(nResult.Int()) + if req.N <= 0 { + return fmt.Errorf("n must be greater than 0") + } + } + + if sizeResult := gjson.GetBytes(body, "size"); sizeResult.Exists() { + req.Size = 
strings.TrimSpace(sizeResult.String()) + req.ExplicitSize = req.Size != "" + } + req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) + req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String()) + req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String()) + req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String()) + req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String()) + req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String()) + req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String()) + req.HasMask = gjson.GetBytes(body, "mask").Exists() + if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() { + if outputCompression.Type != gjson.Number { + return fmt.Errorf("invalid output_compression field type") + } + v := int(outputCompression.Int()) + req.OutputCompression = &v + } + if partialImages := gjson.GetBytes(body, "partial_images"); partialImages.Exists() { + if partialImages.Type != gjson.Number { + return fmt.Errorf("invalid partial_images field type") + } + v := int(partialImages.Int()) + req.PartialImages = &v + } + if req.IsEdits() { + images := gjson.GetBytes(body, "images") + if images.Exists() { + if !images.IsArray() { + return fmt.Errorf("invalid images field type") + } + for _, item := range images.Array() { + if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" { + req.InputImageURLs = append(req.InputImageURLs, imageURL) + continue + } + if item.Get("file_id").Exists() { + return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)") + } + } + } + if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" { + req.MaskImageURL = maskImageURL + req.HasMask = true + } + if gjson.GetBytes(body, "mask.file_id").Exists() { + return 
fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)") + } + if len(req.InputImageURLs) == 0 { + return fmt.Errorf("images[].image_url is required") + } + } + req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { + return gjson.GetBytes(body, path).Exists() + }) + return nil +} + +func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error { + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid multipart content-type: %w", err) + } + boundary := strings.TrimSpace(params["boundary"]) + if boundary == "" { + return fmt.Errorf("multipart boundary is required") + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("read multipart body: %w", err) + } + name := strings.TrimSpace(part.FormName()) + if name == "" { + _ = part.Close() + continue + } + + data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize)) + _ = part.Close() + if err != nil { + return fmt.Errorf("read multipart field %s: %w", name, err) + } + + fileName := strings.TrimSpace(part.FileName()) + if fileName != "" { + partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) + if name == "mask" && len(data) > 0 { + req.HasMask = true + width, height := parseOpenAIImageDimensions(part.Header) + maskUpload := OpenAIImagesUpload{ + FieldName: name, + FileName: fileName, + ContentType: partContentType, + Data: data, + Width: width, + Height: height, + } + req.MaskUpload = &maskUpload + } + if name == "image" || strings.HasPrefix(name, "image[") { + width, height := parseOpenAIImageDimensions(part.Header) + req.Uploads = append(req.Uploads, OpenAIImagesUpload{ + FieldName: name, + FileName: fileName, + ContentType: partContentType, + Data: data, + Width: width, + Height: height, + }) + } + continue + } + + value := 
strings.TrimSpace(string(data)) + switch name { + case "model": + req.Model = value + req.ExplicitModel = value != "" + case "prompt": + req.Prompt = value + case "size": + req.Size = value + req.ExplicitSize = value != "" + case "response_format": + req.ResponseFormat = strings.ToLower(value) + case "stream": + parsed, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid stream field value") + } + req.Stream = parsed + case "n": + n, err := strconv.Atoi(value) + if err != nil || n <= 0 { + return fmt.Errorf("n must be a positive integer") + } + req.N = n + case "quality": + req.Quality = value + req.HasNativeOptions = true + case "background": + req.Background = value + req.HasNativeOptions = true + case "output_format": + req.OutputFormat = value + req.HasNativeOptions = true + case "moderation": + req.Moderation = value + req.HasNativeOptions = true + case "input_fidelity": + req.InputFidelity = value + req.HasNativeOptions = true + case "style": + req.Style = value + req.HasNativeOptions = true + case "output_compression": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid output_compression field value") + } + req.OutputCompression = &n + req.HasNativeOptions = true + case "partial_images": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid partial_images field value") + } + req.PartialImages = &n + req.HasNativeOptions = true + default: + if isOpenAINativeImageOption(name) && value != "" { + req.HasNativeOptions = true + } + } + } + + if len(req.Uploads) == 0 && req.IsEdits() { + return fmt.Errorf("image file is required") + } + return nil +} + +func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) { + return 0, 0 +} + +func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) { + if req == nil { + return + } + if req.N <= 0 { + req.N = 1 + } + if strings.TrimSpace(req.Model) != "" { + req.Model = strings.TrimSpace(req.Model) + return + } + req.Model = "gpt-image-2" 
+} + +func isOpenAIImageGenerationModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-") +} + +func validateOpenAIImagesModel(model string) error { + model = strings.TrimSpace(model) + if isOpenAIImageGenerationModel(model) { + return nil + } + if model == "" { + return fmt.Errorf("images endpoint requires an image model") + } + return fmt.Errorf("images endpoint requires an image model, got %q", model) +} + +func normalizeOpenAIImagesEndpointPath(path string) string { + trimmed := strings.TrimSpace(path) + switch { + case strings.Contains(trimmed, "/images/generations"): + return openAIImagesGenerationsEndpoint + case strings.Contains(trimmed, "/images/edits"): + return openAIImagesEditsEndpoint + default: + return "" + } +} + +func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability { + if req == nil { + return OpenAIImagesCapabilityNative + } + if req.ExplicitModel || req.ExplicitSize { + return OpenAIImagesCapabilityNative + } + model := strings.ToLower(strings.TrimSpace(req.Model)) + if !strings.HasPrefix(model, "gpt-image-") { + return OpenAIImagesCapabilityNative + } + if req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions { + return OpenAIImagesCapabilityNative + } + if req.IsEdits() && !req.Multipart { + return OpenAIImagesCapabilityNative + } + if req.ResponseFormat != "" && req.ResponseFormat != "b64_json" { + return OpenAIImagesCapabilityNative + } + return OpenAIImagesCapabilityBasic +} + +func hasOpenAINativeImageOptions(exists func(path string) bool) bool { + for _, path := range []string{ + "background", + "quality", + "style", + "output_format", + "output_compression", + "moderation", + "input_fidelity", + "partial_images", + } { + if exists(path) { + return true + } + } + return false +} + +func isOpenAINativeImageOption(name string) bool { + switch strings.TrimSpace(strings.ToLower(name)) { + case "background", "quality", "style", "output_format", 
"output_compression", "moderation", "input_fidelity", "partial_images": + return true + default: + return false + } +} + +func normalizeOpenAIImageSizeTier(size string) string { + switch strings.ToLower(strings.TrimSpace(size)) { + case "1024x1024": + return "1K" + case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": + return "2K" + default: + return "2K" + } +} + +func (s *OpenAIGatewayService) ForwardImages( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + if parsed == nil { + return nil, fmt.Errorf("parsed images request is required") + } + switch account.Type { + case AccountTypeAPIKey: + return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel) + case AccountTypeOAuth: + return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel) + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + requestModel := strings.TrimSpace(parsed.Model) + if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { + requestModel = mapped + } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } + upstreamModel := account.GetMappedModel(requestModel) + if err := validateOpenAIImagesModel(upstreamModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s", + strings.TrimSpace(parsed.Model), + upstreamModel, + parsed.Endpoint, + account.Type, + ) + forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, 
upstreamModel) + if err != nil { + return nil, err + } + if !parsed.Multipart { + setOpsUpstreamRequestBody(c, forwardBody) + } + + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "request_error", + Message: safeErr, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + 
ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleErrorResponse(ctx, resp, c, account, forwardBody) + } + defer func() { _ = resp.Body.Close() }() + + var usage OpenAIUsage + imageCount := parsed.N + var firstTokenMs *int + if parsed.Stream { + streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) + if err != nil { + return nil, err + } + usage = streamUsage + imageCount = streamCount + firstTokenMs = ttft + } else { + nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c) + if err != nil { + return nil, err + } + usage = nonStreamUsage + if nonStreamCount > 0 { + imageCount = nonStreamCount + } + } + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, nil +} + +func (s *OpenAIGatewayService) buildOpenAIImagesRequest( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + contentType string, + token string, + endpoint string, +) (*http.Request, error) { + targetURL := openAIImagesGenerationsURL + if endpoint == openAIImagesEditsEndpoint { + targetURL = openAIImagesEditsURL + } + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIImagesURL(validatedURL, endpoint) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + for key, values := range c.Request.Header { + if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] { + 
continue + } + for _, value := range values { + req.Header.Add(key, value) + } + } + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("User-Agent", customUA) + } + if strings.TrimSpace(contentType) != "" { + req.Header.Set("Content-Type", contentType) + } + return req, nil +} + +func buildOpenAIImagesURL(base string, endpoint string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1") + if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + relative + } + return normalized + endpoint +} + +func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) { + model = strings.TrimSpace(model) + if model == "" { + return body, contentType, nil + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { + rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model) + return rewrittenBody, rewrittenType, rewriteErr + } + rewritten, err := sjson.SetBytes(body, "model", model) + if err != nil { + return nil, "", fmt.Errorf("rewrite image request model: %w", err) + } + return rewritten, contentType, nil +} + +func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) { + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, "", fmt.Errorf("parse multipart content-type: %w", err) + } + boundary := strings.TrimSpace(params["boundary"]) + if boundary == "" { + return nil, "", fmt.Errorf("multipart boundary is required") + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + modelWritten := false + + for { + part, err := 
reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return nil, "", fmt.Errorf("read multipart body: %w", err) + } + + formName := strings.TrimSpace(part.FormName()) + partHeader := cloneMultipartHeader(part.Header) + target, err := writer.CreatePart(partHeader) + if err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("create multipart part: %w", err) + } + + if formName == "model" && part.FileName() == "" { + if _, err := target.Write([]byte(model)); err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("rewrite multipart model: %w", err) + } + modelWritten = true + _ = part.Close() + continue + } + if _, err := io.Copy(target, part); err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("copy multipart part: %w", err) + } + _ = part.Close() + } + + if !modelWritten { + if err := writer.WriteField("model", model); err != nil { + return nil, "", fmt.Errorf("append multipart model field: %w", err) + } + } + if err := writer.Close(); err != nil { + return nil, "", fmt.Errorf("finalize multipart body: %w", err) + } + return buffer.Bytes(), writer.FormDataContentType(), nil +} + +func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader { + dst := make(textproto.MIMEHeader, len(src)) + for key, values := range src { + copied := make([]string, len(values)) + copy(copied, values) + dst[key] = copied + } + return dst +} + +func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) { + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return OpenAIUsage{}, 0, err + } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := "application/json" + if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { + if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { + contentType = upstreamType + } + } + c.Data(resp.StatusCode, 
contentType, body) + + usage, _ := extractOpenAIUsageFromJSONBytes(body) + return usage, extractOpenAIImageCountFromJSONBytes(body), nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( + resp *http.Response, + c *gin.Context, + startTime time.Time, +) (OpenAIUsage, int, *int, error) { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" + } + c.Status(resp.StatusCode) + c.Header("Content-Type", contentType) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") + } + + reader := bufio.NewReader(resp.Body) + usage := OpenAIUsage{} + imageCount := 0 + var firstTokenMs *int + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if _, writeErr := c.Writer.Write(line); writeErr != nil { + return OpenAIUsage{}, 0, firstTokenMs, writeErr + } + flusher.Flush() + + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { + dataBytes := []byte(data) + mergeOpenAIUsage(&usage, dataBytes) + if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { + imageCount = count + } + } + } + if err == io.EOF { + break + } + if err != nil { + return OpenAIUsage{}, 0, firstTokenMs, err + } + } + return usage, imageCount, firstTokenMs, nil +} + +func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { + if dst == nil { + return + } + if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok { + if parsed.InputTokens > 0 { + dst.InputTokens = parsed.InputTokens + } + if parsed.OutputTokens > 0 { + dst.OutputTokens = parsed.OutputTokens + } + if parsed.CacheReadInputTokens > 0 { + dst.CacheReadInputTokens = 
parsed.CacheReadInputTokens + } + if parsed.ImageOutputTokens > 0 { + dst.ImageOutputTokens = parsed.ImageOutputTokens + } + } +} + +func extractOpenAIImageCountFromJSONBytes(body []byte) int { + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + data := gjson.GetBytes(body, "data") + if data.Exists() && data.IsArray() { + return len(data.Array()) + } + return 0 +} + +type openAIImagePointerInfo struct { + Pointer string + DownloadURL string + B64JSON string + MimeType string + Prompt string +} + +func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { + if len(body) == 0 { + return nil + } + prompt := "" + for _, path := range []string{ + "message.metadata.dalle.prompt", + "metadata.dalle.prompt", + "revised_prompt", + } { + if value := strings.TrimSpace(gjson.GetBytes(body, path).String()); value != "" { + prompt = value + break + } + } + matches := openAIImagePointerMatches(body) + out := make([]openAIImagePointerInfo, 0, len(matches)) + for _, pointer := range matches { + out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}) + } + return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt)) +} + +func openAIImagePointerMatches(body []byte) []string { + raw := string(body) + matches := make([]string, 0, 4) + for _, prefix := range []string{"file-service://", "sediment://"} { + start := 0 + for { + idx := strings.Index(raw[start:], prefix) + if idx < 0 { + break + } + idx += start + end := idx + len(prefix) + for end < len(raw) { + ch := raw[end] + if ch != '-' && ch != '_' && + (ch < '0' || ch > '9') && + (ch < 'a' || ch > 'z') && + (ch < 'A' || ch > 'Z') { + break + } + end++ + } + matches = append(matches, raw[idx:end]) + start = end + } + } + return dedupeStrings(matches) +} + +func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo { + if len(next) == 0 { + return existing + } + seen := 
make(map[string]openAIImagePointerInfo, len(existing)+len(next)) + out := make([]openAIImagePointerInfo, 0, len(existing)+len(next)) + for _, item := range existing { + if key := item.identityKey(); key != "" { + seen[key] = item + } + out = append(out, item) + } + for _, item := range next { + key := item.identityKey() + if key == "" { + continue + } + if existingItem, ok := seen[key]; ok { + merged := mergeOpenAIImagePointerInfo(existingItem, item) + if merged != existingItem { + for i := range out { + if out[i].identityKey() == key { + out[i] = merged + break + } + } + seen[key] = merged + } + continue + } + seen[key] = item + out = append(out, item) + } + return out +} + +func (i openAIImagePointerInfo) identityKey() string { + switch { + case strings.TrimSpace(i.Pointer) != "": + return "pointer:" + strings.TrimSpace(i.Pointer) + case strings.TrimSpace(i.DownloadURL) != "": + return "download:" + strings.TrimSpace(i.DownloadURL) + case strings.TrimSpace(i.B64JSON) != "": + b64 := strings.TrimSpace(i.B64JSON) + if len(b64) > 64 { + b64 = b64[:64] + } + return "b64:" + b64 + default: + return "" + } +} + +func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo { + merged := existing + if strings.TrimSpace(merged.Pointer) == "" { + merged.Pointer = next.Pointer + } + if strings.TrimSpace(merged.DownloadURL) == "" { + merged.DownloadURL = next.DownloadURL + } + if strings.TrimSpace(merged.B64JSON) == "" { + merged.B64JSON = next.B64JSON + } + if strings.TrimSpace(merged.MimeType) == "" { + merged.MimeType = next.MimeType + } + if strings.TrimSpace(merged.Prompt) == "" { + merged.Prompt = next.Prompt + } + return merged +} + +func resolveOpenAIImageBytes( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointer openAIImagePointerInfo, +) ([]byte, error) { + if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" { + return 
base64.StdEncoding.DecodeString(normalized) + } + if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" { + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + } + if strings.TrimSpace(pointer.Pointer) == "" { + return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data") + } + downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) + if err != nil { + return nil, err + } + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) +} + +func normalizeOpenAIImageBase64(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(raw), "data:") { + if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) { + raw = raw[idx+1:] + } + } + raw = strings.TrimSpace(raw) + raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4) + if raw == "" { + return "" + } + if _, err := base64.StdEncoding.DecodeString(raw); err != nil { + return "" + } + return raw +} + +func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo { + if len(body) == 0 || !gjson.ValidBytes(body) { + return nil + } + var decoded any + if err := json.Unmarshal(body, &decoded); err != nil { + return nil + } + var out []openAIImagePointerInfo + walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out) + return out +} + +func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) { + switch value := node.(type) { + case map[string]any: + localPrompt := prompt + for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} { + if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" { + localPrompt = strings.TrimSpace(v) + break + } + } + item := openAIImagePointerInfo{ + Prompt: localPrompt, + Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]), + DownloadURL: 
firstNonEmptyString(value["download_url"], value["url"], value["image_url"]), + B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]), + MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]), + } + switch { + case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"), + strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"), + isLikelyOpenAIImageDownloadURL(item.DownloadURL), + normalizeOpenAIImageBase64(item.B64JSON) != "": + *out = append(*out, item) + } + for _, child := range value { + walkOpenAIImageInlineAssets(child, localPrompt, out) + } + case []any: + for _, child := range value { + walkOpenAIImageInlineAssets(child, prompt, out) + } + } +} + +func firstNonEmptyString(values ...any) string { + for _, value := range values { + if s, ok := value.(string); ok && strings.TrimSpace(s) != "" { + return strings.TrimSpace(s) + } + } + return "" +} + +func isLikelyOpenAIImageDownloadURL(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "data:image/") { + return true + } + if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") { + return false + } + lower := strings.ToLower(raw) + return strings.Contains(lower, "/download") || + strings.Contains(lower, ".png") || + strings.Contains(lower, ".jpg") || + strings.Contains(lower, ".jpeg") || + strings.Contains(lower, ".webp") +} + +func fetchOpenAIImageDownloadURL( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointer string, +) (string, error) { + url := "" + allowConversationRetry := false + switch { + case strings.HasPrefix(pointer, "file-service://"): + fileID := strings.TrimPrefix(pointer, "file-service://") + url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID) + case strings.HasPrefix(pointer, "sediment://"): + 
attachmentID := strings.TrimPrefix(pointer, "sediment://") + url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID) + allowConversationRetry = true + default: + return "", fmt.Errorf("unsupported image pointer: %s", pointer) + } + + var lastErr error + for attempt := 0; attempt < 8; attempt++ { + var result struct { + DownloadURL string `json:"download_url"` + } + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetSuccessResult(&result). + Get(url) + if err != nil { + lastErr = err + } else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" { + return strings.TrimSpace(result.DownloadURL), nil + } else { + statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed") + if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) { + return "", statusErr + } + lastErr = statusErr + } + if attempt == 7 { + break + } + timer := time.NewTimer(750 * time.Millisecond) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return "", ctx.Err() + case <-timer.C: + } + } + if lastErr == nil { + lastErr = fmt.Errorf("fetch image download url failed") + } + return "", lastErr +} + +func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) { + request := client.R(). + SetContext(ctx). 
+ DisableAutoReadResponse() + + if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) { + downloadHeaders := cloneHTTPHeader(headers) + downloadHeaders.Set("Accept", "image/*,*/*;q=0.8") + downloadHeaders.Del("Content-Type") + request.SetHeaders(headerToMap(downloadHeaders)) + } else { + userAgent := strings.TrimSpace(headers.Get("User-Agent")) + if userAgent == "" { + userAgent = openAIImageBackendUserAgent + } + request.SetHeader("User-Agent", userAgent) + } + + resp, err := request.Get(downloadURL) + if err != nil { + return nil, err + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, newOpenAIImageStatusError(resp, "download image bytes failed") + } + return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes)) +} + +type openAIImageStatusError struct { + StatusCode int + Message string + ResponseBody []byte + ResponseHeaders http.Header + RequestID string + URL string +} + +func (e *openAIImageStatusError) Error() string { + if e == nil { + return "openai image backend request failed" + } + if e.Message != "" { + return e.Message + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai image backend request failed: status %d", e.StatusCode) + } + return "openai image backend request failed" +} + +func newOpenAIImageStatusError(resp *req.Response, fallback string) error { + if resp == nil { + if strings.TrimSpace(fallback) == "" { + fallback = "openai image backend request failed" + } + return fmt.Errorf("%s", fallback) + } + + statusCode := resp.StatusCode + headers := http.Header(nil) + requestID := "" + requestURL := "" + body := []byte(nil) + + if resp.Response != nil { + headers = resp.Header.Clone() + requestID = strings.TrimSpace(resp.Header.Get("x-request-id")) + if resp.Request != nil && resp.Request.URL != nil { + requestURL = resp.Request.URL.String() + } + if resp.Body != nil { + body, _ = 
io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + } + } + + message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body)) + if message == "" { + prefix := strings.TrimSpace(fallback) + if prefix == "" { + prefix = "openai image backend request failed" + } + message = fmt.Sprintf("%s: status %d", prefix, statusCode) + } + + return &openAIImageStatusError{ + StatusCode: statusCode, + Message: message, + ResponseBody: body, + ResponseHeaders: headers, + RequestID: requestID, + URL: requestURL, + } +} + +func isOpenAIImageTransientConversationNotFoundError(err error) bool { + statusErr, ok := err.(*openAIImageStatusError) + if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(statusErr.Message)) + if strings.Contains(msg, "conversation_not_found") { + return true + } + if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") { + return true + } + bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody))) + if strings.Contains(bodyMsg, "conversation_not_found") { + return true + } + return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found") +} + +func cloneHTTPHeader(src http.Header) http.Header { + dst := make(http.Header, len(src)) + for key, values := range src { + copied := make([]string, len(values)) + copy(copied, values) + dst[key] = copied + } + return dst +} + +func headerToMap(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(header)) + for key, values := range header { + if len(values) == 0 { + continue + } + result[key] = values[0] + } + return result +} + +func dedupeStrings(values []string) []string { + if len(values) == 0 { + return nil + } + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, value := range values { + if _, ok := 
seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go new file mode 100644 index 0000000000000000000000000000000000000000..64d995e138f7194b7ef11e02f085abc1e8d25733 --- /dev/null +++ b/backend/internal/service/openai_images_responses.go @@ -0,0 +1,853 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type openAIResponsesImageResult struct { + Result string + RevisedPrompt string + OutputFormat string + Size string + Background string + Quality string + Model string +} + +func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string { + if strings.TrimSpace(result.Result) != "" { + return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result) + } + return "item:" + strings.TrimSpace(itemID) +} + +func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool { + if results == nil { + return false + } + key := openAIResponsesImageResultKey(itemID, result) + if key != "" { + if _, exists := seen[key]; exists { + return false + } + seen[key] = struct{}{} + } + *results = append(*results, result) + return true +} + +func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) { + if dst == nil { + return + } + if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" { + dst.OutputFormat = trimmed + } + if trimmed := strings.TrimSpace(src.Size); trimmed != "" { + dst.Size = trimmed + } + if trimmed := 
strings.TrimSpace(src.Background); trimmed != "" { + dst.Background = trimmed + } + if trimmed := strings.TrimSpace(src.Quality); trimmed != "" { + dst.Quality = trimmed + } + if trimmed := strings.TrimSpace(src.Model); trimmed != "" { + dst.Model = trimmed + } +} + +func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) { + switch gjson.GetBytes(payload, "type").String() { + case "response.created", "response.in_progress", "response.completed": + default: + return openAIResponsesImageResult{}, 0, false + } + + response := gjson.GetBytes(payload, "response") + if !response.Exists() { + return openAIResponsesImageResult{}, 0, false + } + + meta := openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()), + Size: strings.TrimSpace(response.Get("tools.0.size").String()), + Background: strings.TrimSpace(response.Get("tools.0.background").String()), + Quality: strings.TrimSpace(response.Get("tools.0.quality").String()), + Model: strings.TrimSpace(response.Get("tools.0.model").String()), + } + return meta, response.Get("created_at").Int(), true +} + +func buildOpenAIImagesStreamPartialPayload( + eventType string, + b64 string, + partialImageIndex int64, + responseFormat string, + createdAt int64, + meta openAIResponsesImageResult, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex) + payload, _ = sjson.SetBytes(payload, "b64_json", b64) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64) + } + if meta.Background != "" { + 
payload, _ = sjson.SetBytes(payload, "background", meta.Background) + } + if meta.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat) + } + if meta.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", meta.Quality) + } + if meta.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", meta.Size) + } + if meta.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", meta.Model) + } + return payload +} + +func buildOpenAIImagesStreamCompletedPayload( + eventType string, + img openAIResponsesImageResult, + responseFormat string, + createdAt int64, + usageRaw []byte, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "b64_json", img.Result) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } + if img.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", img.Background) + } + if img.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat) + } + if img.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", img.Quality) + } + if img.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", img.Size) + } + if img.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", img.Model) + } + if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw) + } + return payload +} + +func openAIImageOutputMIMEType(outputFormat string) string { + if outputFormat == "" { + return "image/png" + } + if strings.Contains(outputFormat, "/") { + return outputFormat + } + switch strings.ToLower(strings.TrimSpace(outputFormat)) 
{ + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "webp": + return "image/webp" + default: + return "image/png" + } +} + +func openAIImageUploadToDataURL(upload OpenAIImagesUpload) (string, error) { + if len(upload.Data) == 0 { + return "", fmt.Errorf("upload %q is empty", strings.TrimSpace(upload.FileName)) + } + contentType := strings.TrimSpace(upload.ContentType) + if contentType == "" { + contentType = http.DetectContentType(upload.Data) + } + return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(upload.Data), nil +} + +func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel string) ([]byte, error) { + if parsed == nil { + return nil, fmt.Errorf("parsed images request is required") + } + prompt := strings.TrimSpace(parsed.Prompt) + if prompt == "" { + return nil, fmt.Errorf("prompt is required") + } + + inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads)) + for _, imageURL := range parsed.InputImageURLs { + if trimmed := strings.TrimSpace(imageURL); trimmed != "" { + inputImages = append(inputImages, trimmed) + } + } + for _, upload := range parsed.Uploads { + dataURL, err := openAIImageUploadToDataURL(upload) + if err != nil { + return nil, err + } + inputImages = append(inputImages, dataURL) + } + if parsed.IsEdits() && len(inputImages) == 0 { + return nil, fmt.Errorf("image input is required") + } + + req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`) + req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel) + + input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`) + input, _ = sjson.SetBytes(input, "0.content.0.text", prompt) + for index, imageURL := range inputImages { + part := 
[]byte(`{"type":"input_image","image_url":""}`) + part, _ = sjson.SetBytes(part, "image_url", imageURL) + input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part) + } + req, _ = sjson.SetRawBytes(req, "input", input) + + action := "generate" + if parsed.IsEdits() { + action = "edit" + } + tool := []byte(`{"type":"image_generation","action":"","model":""}`) + tool, _ = sjson.SetBytes(tool, "action", action) + tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel)) + + for _, field := range []struct { + path string + value string + }{ + {path: "size", value: parsed.Size}, + {path: "quality", value: parsed.Quality}, + {path: "background", value: parsed.Background}, + {path: "output_format", value: parsed.OutputFormat}, + {path: "moderation", value: parsed.Moderation}, + {path: "style", value: parsed.Style}, + } { + if trimmed := strings.TrimSpace(field.value); trimmed != "" { + tool, _ = sjson.SetBytes(tool, field.path, trimmed) + } + } + if parsed.OutputCompression != nil { + tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression) + } + if parsed.PartialImages != nil { + tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages) + } + + maskImageURL := strings.TrimSpace(parsed.MaskImageURL) + if parsed.MaskUpload != nil { + dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload) + if err != nil { + return nil, err + } + maskImageURL = dataURL + } + if maskImageURL != "" { + tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL) + } + + req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`)) + req, _ = sjson.SetRawBytes(req, "tools.-1", tool) + return req, nil +} + +func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) { + if gjson.GetBytes(payload, "type").String() != "response.completed" { + return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type") + } 
+ + createdAt := gjson.GetBytes(payload, "response.created_at").Int() + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + var ( + results []openAIResponsesImageResult + firstMeta openAIResponsesImageResult + ) + output := gjson.GetBytes(payload, "response.output") + if output.IsArray() { + for _, item := range output.Array() { + if item.Get("type").String() != "image_generation_call" { + continue + } + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + continue + } + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + if len(results) == 0 { + firstMeta = entry + } + results = append(results, entry) + } + } + + var usageRaw []byte + if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() { + usageRaw = []byte(usage.Raw) + } + return results, createdAt, usageRaw, firstMeta, nil +} + +func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) { + if gjson.GetBytes(payload, "type").String() != "response.output_item.done" { + return openAIResponsesImageResult{}, "", false, fmt.Errorf("unexpected event type") + } + + item := gjson.GetBytes(payload, "item") + if !item.Exists() || item.Get("type").String() != "image_generation_call" { + return openAIResponsesImageResult{}, "", false, nil + } + + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + return openAIResponsesImageResult{}, "", false, nil + } + + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: 
strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + return entry, strings.TrimSpace(item.Get("id").String()), true, nil +} + +func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) { + var ( + fallbackResults []openAIResponsesImageResult + fallbackSeen = make(map[string]struct{}) + createdAt int64 + usageRaw []byte + foundFinal bool + responseMeta openAIResponsesImageResult + ) + + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + payload := []byte(data) + if !gjson.ValidBytes(payload) { + continue + } + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { + mergeOpenAIResponsesImageMeta(&responseMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + + switch gjson.GetBytes(payload, "type").String() { + case "response.output_item.done": + result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + if ok { + mergeOpenAIResponsesImageMeta(&result, responseMeta) + appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result) + } + case "response.completed": + results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + foundFinal = true + if completedAt > 0 { + createdAt = completedAt + } + if len(completedUsageRaw) > 0 { + usageRaw = completedUsageRaw + } + if len(results) > 0 { + 
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return results, createdAt, usageRaw, firstMeta, true, nil + } + if len(fallbackResults) > 0 { + firstMeta = fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, true, nil + } + } + } + + if len(fallbackResults) > 0 { + firstMeta := fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil + } + return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil +} + +func buildOpenAIImagesAPIResponse( + results []openAIResponsesImageResult, + createdAt int64, + usageRaw []byte, + firstMeta openAIResponsesImageResult, + responseFormat string, +) ([]byte, error) { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + out := []byte(`{"created":0,"data":[]}`) + out, _ = sjson.SetBytes(out, "created", createdAt) + + format := strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + for _, img := range results { + item := []byte(`{}`) + if format == "url" { + item, _ = sjson.SetBytes(item, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } else { + item, _ = sjson.SetBytes(item, "b64_json", img.Result) + } + if img.RevisedPrompt != "" { + item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt) + } + out, _ = sjson.SetRawBytes(out, "data.-1", item) + } + if firstMeta.Background != "" { + out, _ = sjson.SetBytes(out, "background", firstMeta.Background) + } + if firstMeta.OutputFormat != "" { + out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat) + } + if firstMeta.Quality != "" { + out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality) + } + if firstMeta.Size != "" { + out, _ = sjson.SetBytes(out, "size", firstMeta.Size) + } + if firstMeta.Model != "" { + out, _ = sjson.SetBytes(out, "model", firstMeta.Model) + } + if 
len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + out, _ = sjson.SetRawBytes(out, "usage", usageRaw) + } + return out, nil +} + +func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string { + if parsed != nil && parsed.IsEdits() { + return "image_edit" + } + return "image_generation" +} + +func buildOpenAIImagesStreamErrorBody(message string) []byte { + body := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`) + if strings.TrimSpace(message) == "" { + message = "upstream request failed" + } + body, _ = sjson.SetBytes(body, "error.message", message) + return body +} + +func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error { + if strings.TrimSpace(eventName) != "" { + if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { + return err + } + } + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { + return err + } + flusher.Flush() + return nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( + resp *http.Response, + c *gin.Context, + responseFormat string, + fallbackModel string, +) (OpenAIUsage, int, error) { + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return OpenAIUsage{}, 0, err + } + + var usage OpenAIUsage + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + } + results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return OpenAIUsage{}, 0, err + } + if len(results) == 0 { + return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output") + } + if strings.TrimSpace(firstMeta.Model) == "" { + firstMeta.Model = 
strings.TrimSpace(fallbackModel) + } + + responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) + if err != nil { + return OpenAIUsage{}, 0, err + } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody) + return usage, len(results), nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( + resp *http.Response, + c *gin.Context, + startTime time.Time, + responseFormat string, + streamPrefix string, + fallbackModel string, +) (OpenAIUsage, int, *int, error) { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Status(resp.StatusCode) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") + } + + format := strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + + reader := bufio.NewReader(resp.Body) + usage := OpenAIUsage{} + imageCount := 0 + var firstTokenMs *int + emitted := make(map[string]struct{}) + pendingResults := make([]openAIResponsesImageResult, 0, 1) + pendingSeen := make(map[string]struct{}) + streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} + var createdAt int64 + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + trimmedLine := strings.TrimRight(string(line), "\r\n") + data, ok := extractOpenAISSEDataLine(trimmedLine) + if ok && data != "" && data != "[DONE]" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + if gjson.ValidBytes(dataBytes) { + if meta, eventCreatedAt, ok := 
extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + switch gjson.GetBytes(dataBytes, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) + if b64 != "" { + eventName := streamPrefix + ".partial_image" + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + } + case "response.output_item.done": + img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + if !ok { + break + } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey(itemID, img) + if _, exists := emitted[key]; exists { + break + } + if _, exists := pendingSeen[key]; exists { + break + } + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) + case "response.completed": + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", 
buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { + err = fmt.Errorf("upstream did not return image output") + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, err + } + eventName := streamPrefix + ".completed" + for _, img := range finalResults { + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + emitted[key] = struct{}{} + } + imageCount = len(emitted) + return usage, imageCount, firstTokenMs, nil + } + } + } + } + if err == io.EOF { + break + } + if err != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, err + } + } + + if imageCount > 0 { + return usage, imageCount, firstTokenMs, nil + } + if len(pendingResults) > 0 { + eventName := streamPrefix + ".completed" + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + 
continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + emitted[key] = struct{}{} + } + imageCount = len(emitted) + return usage, imageCount, firstTokenMs, nil + } + + streamErr := fmt.Errorf("stream disconnected before image generation completed") + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, streamErr +} + +func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + requestModel := strings.TrimSpace(parsed.Model) + if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { + requestModel = mapped + } + if requestModel == "" { + requestModel = "gpt-image-2" + } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", + requestModel, + parsed.Endpoint, + account.Type, + len(parsed.Uploads), + ) + if parsed.N > 1 { + logger.LegacyPrintf( + "service.openai_gateway", + "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s", + parsed.N, + requestModel, + parsed.Endpoint, + ) + } + + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel) + if err != nil { + return nil, err + } + setOpsUpstreamRequestBody(c, responsesBody) + + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, 
parsed.StickySessionSeed(), false) + if err != nil { + return nil, err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Accept", "text/event-stream") + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "request_error", + Message: safeErr, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleErrorResponse(ctx, resp, c, account, responsesBody) 
+ } + defer func() { _ = resp.Body.Close() }() + + var ( + usage OpenAIUsage + imageCount int + firstTokenMs *int + ) + if parsed.Stream { + usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) + if err != nil { + return nil, err + } + } else { + usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel) + if err != nil { + return nil, err + } + } + if imageCount <= 0 { + imageCount = parsed.N + } + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, nil +} diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go new file mode 100644 index 0000000000000000000000000000000000000000..200547d4989c99e31a51ccfe5cd40ab93aa1ed4d --- /dev/null +++ b/backend/internal/service/openai_images_test.go @@ -0,0 +1,718 @@ +package service + +import ( + "bytes" + "context" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/textproto" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := 
svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "/v1/images/generations", parsed.Endpoint) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, "draw a cat", parsed.Prompt) + require.True(t, parsed.Stream) + require.Equal(t, "1024x1024", parsed.Size) + require.Equal(t, "1K", parsed.SizeTier) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) + require.False(t, parsed.Multipart) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace background")) + require.NoError(t, writer.WriteField("size", "1536x1024")) + part, err := writer.CreateFormFile("image", "source.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake-image-bytes")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "/v1/images/edits", parsed.Endpoint) + require.True(t, parsed.Multipart) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, "replace background", parsed.Prompt) + require.Equal(t, "1536x1024", parsed.Size) + require.Equal(t, "2K", parsed.SizeTier) + require.Len(t, parsed.Uploads, 1) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { + gin.SetMode(gin.TestMode) 
+ + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace foreground")) + require.NoError(t, writer.WriteField("output_format", "png")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_compression", "80")) + require.NoError(t, writer.WriteField("partial_images", "2")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("source-image-bytes")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("mask-image-bytes")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Len(t, parsed.Uploads, 1) + require.NotNil(t, parsed.MaskUpload) + require.True(t, parsed.HasMask) + require.Equal(t, "png", parsed.OutputFormat) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 80, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.Equal(t, 
OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"prompt":"draw a cat"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.Nil(t, parsed) + require.ErrorContains(t, err, `images endpoint requires an 
image model, got "gpt-5.4"`) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + "model":"gpt-image-2", + "prompt":"replace the background", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "input_fidelity":"high", + "output_compression":90, + "partial_images":2, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, []string{"https://example.com/source.png"}, parsed.InputImageURLs) + require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 90, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.True(t, parsed.HasMask) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + +func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) { + items := collectOpenAIImagePointers([]byte(`{ + "revised_prompt": "cat astronaut", + "parts": [ + {"b64_json":"QUJD"}, + {"download_url":"https://files.example.com/image.png?sig=1"}, + {"asset_pointer":"file-service://file_123"} + ] + }`)) + + require.Len(t, items, 3) + var sawBase64, sawURL, sawPointer bool + for _, item := range items { + if item.B64JSON == "QUJD" { + sawBase64 = true + require.Equal(t, "cat astronaut", item.Prompt) + } + if item.DownloadURL == "https://files.example.com/image.png?sig=1" { + sawURL = true + } + if item.Pointer == "file-service://file_123" { + 
sawPointer = true + } + } + require.True(t, sawBase64) + require.True(t, sawURL) + require.True(t, sawPointer) +} + +func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) { + data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{ + B64JSON: "data:image/png;base64,QUJD", + }) + require.NoError(t, err) + require.Equal(t, []byte("ABC"), data) +} + +func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic)) + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) +} + +type openAIImageTestSSEEvent struct { + Name string + Data string +} + +func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent { + chunks := strings.Split(body, "\n\n") + events := make([]openAIImageTestSSEEvent, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + var event openAIImageTestSSEEvent + for _, line := range strings.Split(chunk, "\n") { + switch { + case strings.HasPrefix(line, "event: "): + event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + case strings.HasPrefix(line, "data: "): + event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + } + } + if event.Name != "" || event.Data != "" { + events = append(events, event) + } + } + return events +} + +func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) { + for _, event := range events { + if event.Name == name { + return event, true + } + } + return openAIImageTestSSEEvent{}, false +} + +func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a 
cat","size":"1024x1024","quality":"high","n":2}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 42}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + "chatgpt_account_id": "acct-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-image-2", result.Model) + require.Equal(t, "gpt-image-2", result.UpstreamModel) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 22, result.Usage.OutputTokens) + require.Equal(t, 7, result.Usage.ImageOutputTokens) + + require.NotNil(t, upstream.lastReq) + require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String()) + 
require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) + require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta")) + + require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String()) + require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String()) + require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists()) + require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc 
:= &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 2, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image") + require.True(t, ok) + require.Equal(t, "image_generation.partial_image", 
gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace 
background with aurora")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_format", "webp")) + require.NoError(t, writer.WriteField("quality", "high")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("png-image-content")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("png-mask-content")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 100}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_edit_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with 
aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 3, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists()) + require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String()) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,")) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,")) + require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + "model":"gpt-image-2", + "prompt":"replace background with aurora", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "stream":true, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", 
"application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 4, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, 
"https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String()) + require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String()) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image") + require.True(t, ok) + require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed") + require.True(t, ok) + require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String()) + 
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) +} + +func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: "gpt-image-2", + Prompt: "draw a cat", + N: 2, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) + require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) +} + +func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Model: "gpt-image-2", + Prompt: "replace background", + InputFidelity: "high", + InputImageURLs: []string{ + "https://example.com/source.png", + }, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists()) + require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String()) +} + +func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) { + body := []byte( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" + + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + ) + + results, createdAt, usageRaw, firstMeta, foundFinal, err 
:= collectOpenAIImagesFromResponsesBody(body) + require.NoError(t, err) + require.True(t, foundFinal) + require.Equal(t, int64(1710000004), createdAt) + require.Len(t, results, 1) + require.Equal(t, "aGVsbG8=", results[0].Result) + require.Equal(t, "draw a cat", results[0].RevisedPrompt) + require.Equal(t, "png", firstMeta.OutputFormat) + require.JSONEq(t, `{"images":1}`, string(usageRaw)) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_output_item_done"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 5, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := 
svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.NotContains(t, rec.Body.String(), "event: error") +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index cda7e3698aafc025d3bd6d4fb5ff8f043c46b07e..f25863a8809a07d5829b4e3b340774e17ef76adb 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) { } } -func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { +func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) { account := &Account{ Credentials: map[string]any{}, } withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) - if withoutDefault != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") + if withoutDefault != "gpt-5.4" { + t.Fatalf("normalizeCodexModel(...) 
= %q, want %q", withoutDefault, "gpt-5.4") } withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) @@ -87,10 +87,11 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * func TestNormalizeCodexModel(t *testing.T) { cases := map[string]string{ - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", + "gpt-image-2": "gpt-image-2", } for input, expected := range cases { @@ -111,7 +112,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { name: "oauth keeps codex normalization behavior", account: &Account{Type: AccountTypeOAuth}, model: "gemini-3-flash-preview", - want: "gpt-5.1", + want: "gpt-5.4", }, { name: "apikey preserves custom compatible model", diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 1cae6fe52fbc9f1437e64d514f45ae91cb4b4286..08a10a025c0e6b4b5d55bf6ddcce115ecbdbfa43 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -36,11 +36,15 @@ return 0 // - Scheduling: 5-field cron spec (minute hour dom month dow). // - Multi-instance: best-effort Redis leader lock so only one node runs cleanup. // - Safety: deletes in batches to avoid long transactions. 
+// +// 附带:在 runCleanupOnce 末尾调用 ChannelMonitorService.RunDailyMaintenance, +// 统一共享 cron schedule + leader lock + heartbeat,避免再引一套调度。 type OpsCleanupService struct { - opsRepo OpsRepository - db *sql.DB - redisClient *redis.Client - cfg *config.Config + opsRepo OpsRepository + db *sql.DB + redisClient *redis.Client + cfg *config.Config + channelMonitorSvc *ChannelMonitorService instanceID string @@ -57,13 +61,15 @@ func NewOpsCleanupService( db *sql.DB, redisClient *redis.Client, cfg *config.Config, + channelMonitorSvc *ChannelMonitorService, ) *OpsCleanupService { return &OpsCleanupService{ - opsRepo: opsRepo, - db: db, - redisClient: redisClient, - cfg: cfg, - instanceID: uuid.NewString(), + opsRepo: opsRepo, + db: db, + redisClient: redisClient, + cfg: cfg, + channelMonitorSvc: channelMonitorSvc, + instanceID: uuid.NewString(), } } @@ -248,6 +254,15 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet out.dailyPreagg = n } + // Channel monitor 每日维护(聚合昨日明细 + 软删过期明细/聚合)。 + // 失败只记日志,不影响 ops 清理的成功状态(与 ops 各步骤风格一致); + // 维护本身已经把每步错误打到 slog,heartbeat result 不再分项记录。 + if s.channelMonitorSvc != nil { + if err := s.channelMonitorSvc.RunDailyMaintenance(ctx); err != nil { + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] channel monitor maintenance failed: %v", err) + } + } + return out, nil } diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index c0e814ab7bf19fb719faee68cd04517d019005ea..bd40d389e273ea93e38c6e3f8debb0f02552dd8a 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta func detectOpsRetryType(path string) opsRetryRequestType { p := strings.ToLower(strings.TrimSpace(path)) switch { - case strings.Contains(p, "/responses"): + case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"): return opsRetryTypeOpenAI case 
strings.Contains(p, "/v1beta/"): return opsRetryTypeGeminiV1B diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index 569052788564c04c21fa164299e99c20832edb7b..973c601a06de6948c0aefe9534d354afdd7cb745 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -20,6 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return nil, fmt.Errorf("query provider instances: %w", err) } typeInstances := pcGroupByPaymentType(instances) + typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances) resp := &MethodLimitsResponse{ Methods: make(map[string]MethodLimits, len(typeInstances)), } @@ -31,6 +32,41 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return resp, nil } +func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { + if len(typeInstances) == 0 { + return typeInstances + } + + filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances)) + for paymentType, groupedInstances := range typeInstances { + filtered[paymentType] = groupedInstances + } + + for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { + matching := filterEnabledVisibleMethodInstances(instances, method) + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil { + delete(filtered, method) + continue + } + if providerKey == "" { + if len(matching) == 0 { + delete(filtered, method) + continue + } + filtered[method] = matching + continue + } + selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey) + if len(selectedInstances) == 0 { + delete(filtered, method) + continue + } + 
filtered[method] = selectedInstances + } + return filtered +} + // GetMethodLimits returns per-payment-type limits from enabled provider instances. func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go index 73ad66ef03c60127040d391e7dfb091ebced8dc9..4df506d675b6f7110ec0a2d47691a0ec2a52e144 100644 --- a/backend/internal/service/payment_config_limits_test.go +++ b/backend/internal/service/payment_config_limits_test.go @@ -1,10 +1,12 @@ package service import ( + "context" "testing" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" ) func TestUnionFloat(t *testing.T) { @@ -299,3 +301,161 @@ func TestPcInstanceTypeLimits(t *testing.T) { } }) } + +func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) { + tests := []struct { + name string + sourceSetting string + wantAlipaySingleMin float64 + wantAlipaySingleMax float64 + wantGlobalMin float64 + wantGlobalMax float64 + }{ + { + name: "official source", + sourceSetting: VisibleMethodSourceOfficialAlipay, + wantAlipaySingleMin: 10, + wantAlipaySingleMax: 100, + wantGlobalMin: 10, + wantGlobalMax: 300, + }, + { + name: "easypay source", + sourceSetting: VisibleMethodSourceEasyPayAlipay, + wantAlipaySingleMin: 20, + wantAlipaySingleMax: 200, + wantGlobalMin: 20, + wantGlobalMax: 300, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). 
+ SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting, + }, + }, + } + + resp, err := svc.GetAvailableMethodLimits(ctx) + if err != nil { + t.Fatalf("GetAvailableMethodLimits returned error: %v", err) + } + + alipayLimits, ok := resp.Methods[payment.TypeAlipay] + if !ok { + t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods) + } + if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax { + t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax) + } + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + if !ok { + t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) + } + if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { + t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) + } + if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax { + t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax) 
+ } + }) + } +} + +func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Mixed"). + SetConfig("{}"). + SetSupportedTypes("alipay,wxpay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}}, + } + + resp, err := svc.GetAvailableMethodLimits(ctx) + require.NoError(t, err) + + alipayLimits, ok := resp.Methods[payment.TypeAlipay] + require.True(t, ok, "expected alipay limits to remain visible") + require.Equal(t, 10.0, alipayLimits.SingleMin) + require.Equal(t, 200.0, alipayLimits.SingleMax) + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + require.True(t, ok, "expected wxpay limits to remain visible") + require.Equal(t, 30.0, wxpayLimits.SingleMin) + require.Equal(t, 400.0, wxpayLimits.SingleMax) + + require.Equal(t, 10.0, resp.GlobalMin) + require.Equal(t, 400.0, resp.GlobalMax) +} diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 3c406b4534c746319bc6b2e7e37d13b8ca1523fa..ff05e559a499e19bfe4e1b87c55a544a4eaf0fd2 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "strconv" "strings" @@ -11,9 +12,22 @@ import ( "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/payment/provider" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) +// validateProviderConfig runs the provider's constructor to surface config-level +// errors at save time (e.g. wxpay missing certSerial), instead of only failing +// when an order is created. Returns the structured ApplicationError from the +// constructor so the frontend i18n layer can localize it. +// +// Only validates enabled instances — a disabled instance may be a half-filled +// draft the admin will complete later. 
+func (s *PaymentConfigService) validateProviderConfig(providerKey string, config map[string]string) error { + _, err := provider.CreateProvider(providerKey, "_validate_", config) + return err +} + // --- Provider Instance CRUD --- func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) { @@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, - AllowUserRefund: inst.AllowUserRefund, - SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } - resp.Config, err = s.decryptAndMaskConfig(inst.Config) + resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config) if err != nil { return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) } @@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte return result, nil } -func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { - return s.decryptConfig(encrypted) +// decryptAndMaskConfig returns the stored config with sensitive fields omitted. +// Admin UIs display masked placeholders for these; the raw values never leave +// the server. Callers that need the full config (e.g. payment runtime) must +// use decryptConfig directly. 
+func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) { + cfg, err := s.decryptConfig(encrypted) + if err != nil { + return nil, err + } + if cfg == nil { + return nil, nil + } + masked := make(map[string]string, len(cfg)) + for k, v := range cfg { + if isSensitiveProviderConfigField(providerKey, k) { + continue + } + masked[k] = v + } + return masked, nil } // pendingOrderStatuses are order statuses considered "in progress". @@ -71,18 +102,62 @@ var pendingOrderStatuses = []string{ payment.OrderStatusRecharging, } -var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} +// providerSensitiveConfigFields is the authoritative list of config keys that +// are treated as secrets per provider. Must stay in sync with the frontend +// definition at frontend/src/components/payment/providerConfig.ts +// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true). +// +// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl, +// stripe publishableKey) are returned in plaintext by the admin GET API. +var providerSensitiveConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} + +// providerPendingOrderProtectedConfigFields lists config keys that cannot be +// changed while the instance has in-progress orders. This includes secrets plus +// all provider identity fields that are snapshotted into orders or used by +// webhook/refund verification. 
+var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}, "pid": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} + +func isSensitiveProviderConfigField(providerKey, fieldName string) bool { + fields, ok := providerSensitiveConfigFields[providerKey] + if !ok { + return false + } + _, found := fields[strings.ToLower(fieldName)] + return found +} -func isSensitiveConfigField(fieldName string) bool { - lower := strings.ToLower(fieldName) - for _, p := range sensitiveConfigPatterns { - if strings.Contains(lower, p) { +func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool { + fields, ok := providerPendingOrderProtectedConfigFields[providerKey] + if !ok { + return false + } + for fieldName := range fields { + if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) { return true } } return false } +func providerConfigFieldValue(config map[string]string, fieldName string) string { + for key, value := range config { + if strings.EqualFold(key, fieldName) { + return value + } + } + return "" +} + func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { return s.entClient.PaymentOrder.Query(). 
Where( @@ -108,6 +183,14 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil { return nil, err } + if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil { + return nil, err + } + if req.Enabled { + if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil { + return nil, err + } + } enc, err := s.encryptConfig(req.Config) if err != nil { return nil, err @@ -136,18 +219,47 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { // NOTE: This function exceeds 30 lines due to per-field nil-check patch update // boilerplate and pending-order safety checks. func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + current, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("load provider instance: %w", err) + } + var pendingOrderCount *int + getPendingOrderCount := func() (int, error) { + if pendingOrderCount != nil { + return *pendingOrderCount, nil + } + count, err := s.countPendingOrders(ctx, id) + if err != nil { + return 0, fmt.Errorf("check pending orders: %w", err) + } + pendingOrderCount = &count + return count, nil + } + nextEnabled := current.Enabled + if req.Enabled != nil { + nextEnabled = *req.Enabled + } + nextSupportedTypes := current.SupportedTypes + if req.SupportedTypes != nil { + nextSupportedTypes = joinTypes(req.SupportedTypes) + } + if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil { + return nil, err + } + var mergedConfig map[string]string if req.Config != nil { - hasSensitive := false - for k := range req.Config { - if isSensitiveConfigField(k) && req.Config[k] != "" { - hasSensitive = true - break - } + 
currentConfig, err := s.decryptConfig(current.Config) + if err != nil { + return nil, fmt.Errorf("decrypt existing config: %w", err) + } + mergedConfig, err = s.mergeConfig(ctx, id, req.Config) + if err != nil { + return nil, err } - if hasSensitive { - count, err := s.countPendingOrders(ctx, id) + if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) { + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders"). @@ -156,25 +268,40 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } } if req.Enabled != nil && !*req.Enabled { - count, err := s.countPendingOrders(ctx, id) + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders"). WithMetadata(map[string]string{"count": strconv.Itoa(count)}) } } + // Validate merged config when the instance will end up enabled. + // This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time, + // so admins see them in the dialog instead of only when an order is created. 
+ finalEnabled := current.Enabled + if req.Enabled != nil { + finalEnabled = *req.Enabled + } + if finalEnabled { + configToValidate := mergedConfig + if configToValidate == nil { + configToValidate, err = s.decryptConfig(current.Config) + if err != nil { + return nil, fmt.Errorf("decrypt existing config: %w", err) + } + } + if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil { + return nil, err + } + } u := s.entClient.PaymentProviderInstance.UpdateOneID(id) if req.Name != nil { u.SetName(*req.Name) } - if req.Config != nil { - merged, err := s.mergeConfig(ctx, id, req.Config) - if err != nil { - return nil, err - } - enc, err := s.encryptConfig(merged) + if mergedConfig != nil { + enc, err := s.encryptConfig(mergedConfig) if err != nil { return nil, err } @@ -182,17 +309,13 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.SupportedTypes != nil { // Check pending orders before removing payment types - count, err := s.countPendingOrders(ctx, id) + count, err := getPendingOrderCount() if err != nil { - return nil, fmt.Errorf("check pending orders: %w", err) + return nil, err } if count > 0 { // Load current instance to compare types - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err != nil { - return nil, fmt.Errorf("load provider instance: %w", err) - } - oldTypes := strings.Split(inst.SupportedTypes, ",") + oldTypes := strings.Split(current.SupportedTypes, ",") newTypes := req.SupportedTypes for _, ot := range oldTypes { ot = strings.TrimSpace(ot) @@ -237,10 +360,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in if req.RefundEnabled != nil { refundEnabled = *req.RefundEnabled } else { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil { - refundEnabled = inst.RefundEnabled - } + refundEnabled = current.RefundEnabled } if refundEnabled { u.SetAllowUserRefund(true) @@ -282,27 +402,48 @@ func (s 
*PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) } if existing == nil { - return newConfig, nil + existing = map[string]string{} } for k, v := range newConfig { + // Preserve existing secrets when the client submits an empty value + // (admin UI omits the value to indicate "leave unchanged"). + if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { + continue + } existing[k] = v } return existing, nil } -func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { - if encrypted == "" { +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext +// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including +// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, +// letting the admin re-enter the config via the UI to complete the migration. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional +// shim for pre-plaintext records. Remove it (and the encryptionKey field) after +// a few releases once all live deployments have re-saved their provider configs. +func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { return nil, nil } - decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) - if err != nil { - return nil, fmt.Errorf("decrypt config: %w", err) + var cfg map[string]string + if err := json.Unmarshal([]byte(stored), &cfg); err == nil { + return cfg, nil } - var raw map[string]string - if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { - return nil, fmt.Errorf("unmarshal decrypted config: %w", err) + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. 
+ if len(s.encryptionKey) == payment.AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal + if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { + return cfg, nil + } + } } - return raw, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { @@ -317,14 +458,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) } +// encryptConfig serialises a provider config for storage. +// New records are written as plaintext JSON; the historical AES-GCM wrapping +// has been dropped but decryptConfig still accepts old ciphertext during migration. func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { data, err := json.Marshal(cfg) if err != nil { return "", fmt.Errorf("marshal config: %w", err) } - enc, err := payment.Encrypt(string(data), s.encryptionKey) - if err != nil { - return "", fmt.Errorf("encrypt config: %w", err) - } - return enc, nil + return string(data), nil } diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2aaa874f951b677eb7d65c61bf87d813332f5747..e0d2908a71027343721d82e6c353250a34a1aafb 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -3,8 +3,18 @@ package service import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "strconv" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -97,41 +107,52 @@ func TestValidateProviderRequest(t *testing.T) { } } -func TestIsSensitiveConfigField(t *testing.T) { +func TestIsSensitiveProviderConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string - wantSen bool + providerKey string + field string + wantSen bool }{ - // Sensitive fields (contain key/secret/private/password/pkey patterns) - {"secretKey", true}, - {"apiSecret", true}, - {"pkey", true}, - {"privateKey", true}, - {"apiPassword", true}, - {"appKey", true}, - {"SECRET_TOKEN", true}, - {"PrivateData", true}, - {"PASSWORD", true}, - {"mySecretValue", true}, - - // Non-sensitive fields - {"appId", false}, - {"mchId", false}, - {"apiBase", false}, - {"endpoint", false}, - {"merchantNo", false}, - {"paymentMode", false}, - {"notifyUrl", false}, + // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets + {"stripe", "secretKey", true}, + {"stripe", "webhookSecret", true}, + {"stripe", "SecretKey", true}, // case-insensitive + {"stripe", "publishableKey", false}, + {"stripe", "appId", false}, + + // Alipay + {"alipay", "privateKey", true}, + {"alipay", "publicKey", true}, + {"alipay", "alipayPublicKey", true}, + {"alipay", "appId", false}, + {"alipay", "notifyUrl", false}, + + // Wxpay + {"wxpay", "privateKey", true}, + {"wxpay", "apiV3Key", true}, + {"wxpay", "publicKey", true}, + {"wxpay", "publicKeyId", false}, + {"wxpay", "certSerial", false}, + {"wxpay", "mchId", false}, + + // EasyPay + {"easypay", "pkey", true}, + {"easypay", "pid", false}, + {"easypay", "apiBase", false}, + + // Unknown provider: never sensitive + {"unknown", "secretKey", false}, } for _, tc := range tests { - t.Run(tc.field, func(t *testing.T) { + tc := tc + t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) { t.Parallel() - got := isSensitiveConfigField(tc.field) - assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) + got := 
isSensitiveProviderConfigField(tc.providerKey, tc.field) + assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field) }) } } @@ -185,3 +206,403 @@ func TestJoinTypes(t *testing.T) { }) } } + +func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + _, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay Alipay", + Config: map[string]string{ + "pid": "1001", + "pkey": "pkey-1001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"alipay"}, + Enabled: true, + }) + require.NoError(t, err) + + _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "alipay", + Name: "Official Alipay", + Config: map[string]string{"appId": "app-1", "privateKey": "private-key"}, + SupportedTypes: []string{"alipay"}, + Enabled: true, + }) + require.NoError(t, err) +} + +func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay WeChat", + Config: map[string]string{ + "pid": "2001", + "pkey": "pkey-2001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"wxpay"}, + Enabled: 
true, + }) + require.NoError(t, err) + require.NotNil(t, existing) + + candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "wxpay", + Name: "Official WeChat", + Config: validWxpayProviderConfig(t), + SupportedTypes: []string{"wxpay"}, + Enabled: false, + }) + require.NoError(t, err) + + _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ + Enabled: boolPtrValue(true), + }) + require.NoError(t, err) +} + +func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: "easypay", + Name: "EasyPay", + Config: map[string]string{ + "pid": "3001", + "pkey": "pkey-3001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, + SupportedTypes: []string{"alipay"}, + Enabled: false, + }) + require.NoError(t, err) + + _, err = svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Enabled: boolPtrValue(true), + SupportedTypes: []string{"alipay", "wxpay"}, + }) + require.NoError(t, err) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + require.True(t, saved.Enabled) + require.Equal(t, "alipay,wxpay", saved.SupportedTypes) +} + +func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerKey string + createConfig func(*testing.T) map[string]string + supportedType []string + updateConfig map[string]string + fieldName string + wantValue string + }{ + { + name: "wxpay appId", + providerKey: payment.TypeWxpay, + 
createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"appId": "wx-app-updated"}, + fieldName: "appId", + wantValue: "wx-app-test", + }, + { + name: "wxpay mpAppId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfigWithJSAPIAppID, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"}, + fieldName: "mpAppId", + wantValue: "wx-mp-app-test", + }, + { + name: "wxpay mchId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"mchId": "mch-updated"}, + fieldName: "mchId", + wantValue: "mch-test", + }, + { + name: "wxpay publicKeyId", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"}, + fieldName: "publicKeyId", + wantValue: "public-key-id-test", + }, + { + name: "wxpay certSerial", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"certSerial": "cert-serial-updated"}, + fieldName: "certSerial", + wantValue: "cert-serial-test", + }, + { + name: "alipay appId", + providerKey: payment.TypeAlipay, + createConfig: validAlipayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"appId": "alipay-app-updated"}, + fieldName: "appId", + wantValue: "alipay-app-test", + }, + { + name: "easypay pid", + providerKey: payment.TypeEasyPay, + createConfig: validEasyPayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"pid": "pid-updated"}, + fieldName: "pid", + wantValue: "pid-test", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := 
context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: tc.providerKey, + Name: "protected-config-instance", + Config: tc.createConfig(t), + SupportedTypes: tc.supportedType, + Enabled: true, + }) + require.NoError(t, err) + + createPendingProviderConfigOrder(t, ctx, client, instance) + + updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Config: tc.updateConfig, + }) + require.Nil(t, updated) + require.Error(t, err) + require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err)) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + cfg, err := svc.decryptConfig(saved.Config) + require.NoError(t, err) + require.Equal(t, tc.wantValue, cfg[tc.fieldName]) + }) + } +} + +func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerKey string + createConfig func(*testing.T) map[string]string + supportedType []string + updateConfig map[string]string + fieldName string + wantValue string + }{ + { + name: "wxpay notifyUrl", + providerKey: payment.TypeWxpay, + createConfig: validWxpayProviderConfig, + supportedType: []string{payment.TypeWxpay}, + updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"}, + fieldName: "notifyUrl", + wantValue: "https://merchant.example.com/wxpay/notify-v2", + }, + { + name: "alipay same appId", + providerKey: payment.TypeAlipay, + createConfig: validAlipayProviderConfig, + supportedType: []string{payment.TypeAlipay}, + updateConfig: map[string]string{"appId": "alipay-app-test"}, + fieldName: "appId", + wantValue: "alipay-app-test", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t 
*testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + svc := &PaymentConfigService{ + entClient: client, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + } + + instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ + ProviderKey: tc.providerKey, + Name: "safe-config-instance", + Config: tc.createConfig(t), + SupportedTypes: tc.supportedType, + Enabled: true, + }) + require.NoError(t, err) + + createPendingProviderConfigOrder(t, ctx, client, instance) + + updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{ + Config: tc.updateConfig, + }) + require.NoError(t, err) + require.NotNil(t, updated) + + saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID) + require.NoError(t, err) + cfg, err := svc.decryptConfig(saved.Config) + require.NoError(t, err) + require.Equal(t, tc.wantValue, cfg[tc.fieldName]) + }) + } +} + +func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) { + t.Helper() + + user, err := client.User.Create(). + SetEmail("provider-config-pending@example.com"). + SetPasswordHash("hash"). + SetUsername("provider-config-pending-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := strconv.FormatInt(instance.ID, 10) + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID). + SetOutTradeNo("sub2_pending_provider_config_" + instanceID). + SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). 
+ SetProviderKey(instance.ProviderKey). + Save(ctx) + require.NoError(t, err) +} + +func providerPendingOrderPaymentType(providerKey string) string { + switch providerKey { + case payment.TypeWxpay: + return payment.TypeWxpay + case payment.TypeAlipay: + return payment.TypeAlipay + default: + return payment.TypeAlipay + } +} + +func boolPtrValue(v bool) *bool { + return &v +} + +func validAlipayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + return map[string]string{ + "appId": "alipay-app-test", + "privateKey": "alipay-private-key-test", + "notifyUrl": "https://merchant.example.com/alipay/notify", + "returnUrl": "https://merchant.example.com/alipay/return", + } +} + +func validEasyPayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + return map[string]string{ + "pid": "pid-test", + "pkey": "pkey-test", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/easypay/notify", + "returnUrl": "https://merchant.example.com/easypay/return", + } +} + +func validWxpayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return map[string]string{ + "appId": "wx-app-test", + "mchId": "mch-test", + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": "12345678901234567890123456789012", + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-test", + "certSerial": "cert-serial-test", + } +} + +func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string { + t.Helper() + + cfg := validWxpayProviderConfig(t) + cfg["mpAppId"] = "wx-mp-app-test" + return cfg +} diff --git a/backend/internal/service/payment_config_service.go 
b/backend/internal/service/payment_config_service.go index 59764b298cc8cffc55af8303e4f1af3e3a56c4eb..02d061aeeaad72263e0bb6a87eef02dc2ff73bde 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct { CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` + + VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` } // MethodLimits holds per-payment-type limits. @@ -196,6 +201,8 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingHelpImageURL, SettingHelpText, SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, + SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource, + SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { @@ -234,18 +241,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { + types := make([]string, 0, len(strings.Split(raw, ","))) for _, t := range strings.Split(raw, ",") { t = strings.TrimSpace(t) if t != "" { - cfg.EnabledTypes = append(cfg.EnabledTypes, t) + types = append(types, t) } } + cfg.EnabledTypes = NormalizeVisibleMethods(types) } return cfg } // getStripePublishableKey finds the publishable key from the first enabled Stripe provider 
instance. func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + if s.entClient == nil { + return "" + } instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.EnabledEQ(true), @@ -282,25 +294,29 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda } } m := map[string]string{ - SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), - SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), - SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), - SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), - SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), - SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), - SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), - SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier), - SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate), - SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), - SettingProductNamePrefix: derefStr(req.ProductNamePrefix), - SettingProductNameSuffix: derefStr(req.ProductNameSuffix), - SettingHelpImageURL: derefStr(req.HelpImageURL), - SettingHelpText: derefStr(req.HelpText), - SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), - SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), - SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), - SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), - SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), + SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), + SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), + SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), + SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), + SettingMaxPendingOrders: 
formatPositiveInt(req.MaxPendingOrders), + SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), + SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier), + SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate), + SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), + SettingProductNamePrefix: derefStr(req.ProductNamePrefix), + SettingProductNameSuffix: derefStr(req.ProductNameSuffix), + SettingHelpImageURL: derefStr(req.HelpImageURL), + SettingHelpText: derefStr(req.HelpText), + SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), + SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), + SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), + SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), + SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource), + SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource), + SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled), + SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled), } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") @@ -385,3 +401,79 @@ func pcParseInt(s string, defaultVal int) int { } return v } + +func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool { + available := make(map[string]bool, 4) + for _, inst := range instances { + switch inst.ProviderKey { + case payment.TypeAlipay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) { + available[VisibleMethodSourceOfficialAlipay] = true + } + case payment.TypeWxpay: + if inst.SupportedTypes == "" || 
payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) { + available[VisibleMethodSourceOfficialWechat] = true + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(inst.SupportedTypes) { + switch NormalizeVisibleMethod(supportedType) { + case payment.TypeAlipay: + available[VisibleMethodSourceEasyPayAlipay] = true + case payment.TypeWxpay: + available[VisibleMethodSourceEasyPayWechat] = true + } + } + } + } + return available +} + +func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string { + shouldExpose := map[string]bool{ + payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available), + payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available), + } + + seen := make(map[string]struct{}, len(base)+2) + out := make([]string, 0, len(base)+2) + appendType := func(paymentType string) { + paymentType = NormalizeVisibleMethod(paymentType) + if paymentType == "" { + return + } + if _, ok := seen[paymentType]; ok { + return + } + seen[paymentType] = struct{}{} + out = append(out, paymentType) + } + + for _, paymentType := range base { + visibleMethod := NormalizeVisibleMethod(paymentType) + switch visibleMethod { + case payment.TypeAlipay, payment.TypeWxpay: + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + default: + appendType(visibleMethod) + } + } + + for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + } + return out +} + +func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool { + enabledKey := visibleMethodEnabledSettingKey(method) + sourceKey := visibleMethodSourceSettingKey(method) + if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" { + return false + } + 
source := NormalizeVisibleMethodSource(method, vals[sourceKey]) + return source != "" && available[source] +} diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 027bb796fcde5f6dbae4d46c7dfc3ee7d2c1e99e..f04f4697b1d64b016e5b85e83219e58d1ce01445 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -1,9 +1,19 @@ package service import ( + "context" + "database/sql" + "fmt" + "strings" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestPcParseFloat(t *testing.T) { @@ -163,6 +173,20 @@ func TestParsePaymentConfig(t *testing.T) { } }) + t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) { + t.Parallel() + vals := map[string]string{ + SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay", + } + cfg := svc.parsePaymentConfig(vals) + if len(cfg.EnabledTypes) != 2 { + t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes)) + } + if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" { + t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes) + } + }) + t.Run("empty enabled types string", func(t *testing.T) { t.Parallel() vals := map[string]string{ @@ -204,3 +228,210 @@ func TestGetBasePaymentType(t *testing.T) { }) } } + +func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) { + t.Parallel() + + base := []string{"alipay", "wxpay", "stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + 
} + available := map[string]bool{ + VisibleMethodSourceOfficialAlipay: true, + VisibleMethodSourceOfficialWechat: false, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"alipay", "stripe"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) { + t.Parallel() + + base := []string{"stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay, + } + available := map[string]bool{ + VisibleMethodSourceEasyPayAlipay: true, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"stripe", "alipay"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestBuildVisibleMethodSourceAvailability(t *testing.T) { + t.Parallel() + + instances := []*dbent.PaymentProviderInstance{ + {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"}, + {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"}, + {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"}, + } + + got := buildVisibleMethodSourceAvailability(instances) + if !got[VisibleMethodSourceOfficialAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay) + } + if !got[VisibleMethodSourceEasyPayAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay) + } + if 
!got[VisibleMethodSourceOfficialWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat) + } + if !got[VisibleMethodSourceEasyPayWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat) + } +} + +func TestGetPaymentConfigKeepsStoredEnabledTypes(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingEnabledPaymentTypes: "alipay,wxpay,stripe", + }, + }, + } + + cfg, err := svc.GetPaymentConfig(ctx) + if err != nil { + t.Fatalf("GetPaymentConfig returned error: %v", err) + } + + want := []string{payment.TypeAlipay, payment.TypeWxpay, payment.TypeStripe} + if len(cfg.EnabledTypes) != len(want) { + t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes) + } + for i := range want { + if cfg.EnabledTypes[i] != want[i] { + t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes) + } + } +} + +func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { + t.Helper() + + dbName := fmt.Sprintf( + "file:%s?mode=memory&cache=shared", + strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()), + ) + db, err := sql.Open("sqlite", dbName) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatalf("enable foreign keys: %v", err) + } + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { 
_ = client.Close() }) + return client +} + +type paymentConfigSettingRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, nil +} +func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error { + s.updates = make(map[string]string, len(values)) + for key, value := range values { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} +func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } + +func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) { + repo := &paymentConfigSettingRepoStub{values: map[string]string{}} + svc := &PaymentConfigService{settingRepo: repo} + + alipayEnabled := true + wxpayEnabled := false + err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{ + VisibleMethodAlipayEnabled: &alipayEnabled, + VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay), + VisibleMethodWxpayEnabled: &wxpayEnabled, + VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat), + }) + if err != nil { + t.Fatalf("UpdatePaymentConfig returned error: %v", err) + } + + if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" 
{ + t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay { + t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay) + } + if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" { + t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat { + t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat) + } +} + +func paymentConfigStrPtr(value string) *string { + return &value +} diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 44818b37e6ff50d7947ad018d5a85eeb9a38e204..243edff32254ce666f7a5e90b992c287cecf3ee1 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log/slog" "math" @@ -16,6 +17,14 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) +// ErrOrderNotFound is returned by HandlePaymentNotification when the webhook +// references an out_trade_no that does not exist in our DB. Callers (webhook +// handlers) should treat this as a terminal, non-retryable condition and still +// respond with a 2xx success to the provider — otherwise the provider will keep +// retrying forever (e.g. when a foreign environment's webhook endpoint is +// misconfigured to point at us, or when our orders table has been wiped). 
+var ErrOrderNotFound = errors.New("payment order not found") + // --- Payment Notification & Fulfillment --- func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error { @@ -25,37 +34,102 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme // Look up order by out_trade_no (the external order ID we sent to the provider) order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) if err != nil { - // Fallback: try legacy format (sub2_N where N is DB ID) - trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) - if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + // Fallback only for true legacy "sub2_N" DB-ID payloads when the + // current out_trade_no lookup genuinely did not find an order. + if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok { + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata) + } + if dbent.IsNotFound(err) { + return fmt.Errorf("%w: out_trade_no=%s", ErrOrderNotFound, n.OrderID) } - return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) + return fmt.Errorf("lookup order failed for out_trade_no %s: %w", n.OrderID, err) } - return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata) } -func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { +func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) { + if !dbent.IsNotFound(lookupErr) { + return 0, false + } + orderID = strings.TrimSpace(orderID) + if !strings.HasPrefix(orderID, orderIDPrefix) { + return 0, false + } + trimmed := strings.TrimPrefix(orderID, orderIDPrefix) + if trimmed == "" || trimmed == orderID { + return 0, false + } + oid, err := strconv.ParseInt(trimmed, 
10, 64) + if err != nil || oid <= 0 { + return 0, false + } + return oid, true +} + +func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error { o, err := s.entClient.PaymentOrder.Get(ctx, oid) if err != nil { slog.Error("order not found", "orderID", oid) return nil } - // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). - // Also skip if paid is NaN/Inf (malformed provider data). - if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { - if math.Abs(paid-o.PayAmount) > amountToleranceCNY { - s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) - return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) - } + instanceProviderKey := "" + if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { + instanceProviderKey = inst.ProviderKey + } + expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey) + if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ + "expectedProvider": expectedProviderKey, + "actualProvider": pk, + "tradeNo": tradeNo, + }) + return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) + } + if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{ + "detail": err.Error(), + "tradeNo": tradeNo, + }) + return err + } + if !isValidProviderAmount(paid) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{ + "expected": o.PayAmount, + "paid": paid, + "tradeNo": tradeNo, + }) + return fmt.Errorf("invalid paid amount from provider: %v", paid) } - // Use order's 
expected amount when provider didn't report one - if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) { - paid = o.PayAmount + if math.Abs(paid-o.PayAmount) > amountToleranceCNY { + s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) + return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) } return s.toPaid(ctx, o, tradeNo, paid, pk) } +func isValidProviderAmount(amount float64) bool { + return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0) +} + +func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { + return validateProviderSnapshotMetadata(order, providerKey, metadata) +} + +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { + if key := strings.TrimSpace(instanceProviderKey); key != "" { + return key + } + if key := strings.TrimSpace(orderProviderKey); key != "" { + return key + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(orderPaymentType) +} + func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error { previousStatus := o.Status now := time.Now() diff --git a/backend/internal/service/payment_fulfillment_order_not_found_test.go b/backend/internal/service/payment_fulfillment_order_not_found_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f6787e29e0fd69888e861d7065ddbe4a5fb76b39 --- /dev/null +++ b/backend/internal/service/payment_fulfillment_order_not_found_test.go @@ -0,0 +1,106 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "errors" + "testing" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ 
"modernc.org/sqlite" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +// newOrderNotFoundTestClient wires an in-memory sqlite-backed ent.Client so +// tests can exercise HandlePaymentNotification's real DB lookup path without +// standing up a service stack. +func newOrderNotFoundTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_order_not_found?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +// TestHandlePaymentNotification_UnknownOrder_ReturnsSentinel exercises the +// happy-path of the webhook 404 fix: when the notification references an +// out_trade_no that does not exist in our DB, HandlePaymentNotification must +// return an error that errors.Is(err, ErrOrderNotFound) recognizes. The +// webhook handler relies on that contract to ack with a 2xx so the provider +// stops retrying. 
+func TestHandlePaymentNotification_UnknownOrder_ReturnsSentinel(t *testing.T) { + ctx := context.Background() + client := newOrderNotFoundTestClient(t) + + svc := &PaymentService{ + entClient: client, + providersLoaded: true, + } + + notification := &payment.PaymentNotification{ + OrderID: "sub2_does_not_exist_12345", + TradeNo: "stripe_evt_test_xyz", + Status: payment.NotificationStatusSuccess, + Amount: 1000, + } + + err := svc.HandlePaymentNotification(ctx, notification, payment.TypeStripe) + require.Error(t, err, "unknown out_trade_no should surface an error") + require.ErrorIs(t, err, ErrOrderNotFound, + "webhook handler relies on errors.Is(err, ErrOrderNotFound) to downgrade to 200") + + // Sanity: the wrapped error message should still include the out_trade_no + // for operator diagnostics. + require.Contains(t, err.Error(), notification.OrderID) +} + +// TestHandlePaymentNotification_NonSuccessStatus_Skips documents the +// short-circuit that precedes the DB lookup: when the notification is not a +// success event (e.g. Stripe non-payment events that reach us via the webhook +// route), we return nil without touching the DB and the handler responds 200. +func TestHandlePaymentNotification_NonSuccessStatus_Skips(t *testing.T) { + ctx := context.Background() + client := newOrderNotFoundTestClient(t) + + svc := &PaymentService{ + entClient: client, + providersLoaded: true, + } + + notification := &payment.PaymentNotification{ + OrderID: "sub2_does_not_exist_12345", + Status: "failed", // any value other than NotificationStatusSuccess + } + + err := svc.HandlePaymentNotification(ctx, notification, payment.TypeStripe) + require.NoError(t, err, + "non-success notifications must short-circuit before the DB lookup") +} + +// TestErrOrderNotFound_DistinctFromOtherErrors guards against an accidental +// collapse where a generic wrapped error would start matching ErrOrderNotFound +// (which would silently mask real DB failures). 
+func TestErrOrderNotFound_DistinctFromOtherErrors(t *testing.T) { + genericErr := errors.New("some other failure") + require.False(t, errors.Is(genericErr, ErrOrderNotFound)) + require.False(t, errors.Is(ErrOrderNotFound, genericErr)) + + wrappedLookupErr := errors.New("lookup order failed for out_trade_no sub2_42: connection refused") + require.False(t, errors.Is(wrappedLookupErr, ErrOrderNotFound), + "DB connection failures must not masquerade as order-not-found") +} diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 625b0d9f0ed9bed88489abc63bd7b5a5c24b6eb7..abdb59deacaa34f9c1b638791ac5ec33ea58745f 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -3,12 +3,39 @@ package service import ( + "context" "errors" + "math" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) +type paymentFulfillmentTestProvider struct { + key string + supportedTypes []payment.PaymentType +} + +func (p paymentFulfillmentTestProvider) Name() string { return p.key } +func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key } +func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType { + return p.supportedTypes +} +func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req 
payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -161,3 +188,181 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) { assert.True(t, unusedCode.CanUse()) assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil)) } + +func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay), + ) +} + +func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeEasyPay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { + t.Parallel() + + assert.Equal(t, + payment.TypeWxpay, + expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""), + ) +} + +func 
TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_key": payment.TypeEasyPay, + }, + } + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKeyForOrder(registry, order, ""), + ) +} + +func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "merchant_app_id": "wx-app-expected", + "merchant_id": "mch-expected", + "currency": "CNY", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-other", + "mchid": "mch-expected", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.ErrorContains(t, err, "wxpay appid mismatch") +} + +func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "9", + "provider_key": payment.TypeWxpay, + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{ + "appid": "wx-app-runtime", + "mchid": "mch-runtime", + "currency": "CNY", + "trade_state": "SUCCESS", + }) + assert.NoError(t, err) +} + +func TestParseLegacyPaymentOrderID(t *testing.T) { + t.Parallel() + + oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{}) + assert.True(t, ok) + assert.EqualValues(t, 42, oid) + + _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{}) + 
assert.False(t, ok) + + _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down")) + assert.False(t, ok) +} + +func TestIsValidProviderAmount(t *testing.T) { + t.Parallel() + + assert.True(t, isValidProviderAmount(0.01)) + assert.False(t, isValidProviderAmount(0)) + assert.False(t, isValidProviderAmount(-1)) + assert.False(t, isValidProviderAmount(math.NaN())) + assert.False(t, isValidProviderAmount(math.Inf(1))) +} + +func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "merchant_app_id": "alipay-app-expected", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{ + "app_id": "alipay-app-other", + }) + assert.ErrorContains(t, err, "alipay app_id mismatch") +} + +func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "merchant_id": "pid-expected", + }, + } + + err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{ + "pid": "pid-other", + }) + assert.ErrorContains(t, err, "easypay pid mismatch") +} diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 128416e406a60487172b55a50727968ea2dc229d..15d4509d4a79febddcafa6eb9e72a402c2c99676 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -2,9 +2,11 @@ package service import ( "context" + "errors" "fmt" "log/slog" "math" + "net/url" "strconv" "strings" "time" @@ -22,6 +24,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest if req.OrderType == "" { req.OrderType = payment.OrderTypeBalance } + if normalized := 
NormalizeVisibleMethod(req.PaymentType); normalized != "" { + req.PaymentType = normalized + } cfg, err := s.configService.GetPaymentConfig(ctx) if err != nil { return nil, fmt.Errorf("get payment config: %w", err) @@ -54,11 +59,25 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest feeRate := cfg.RechargeFeeRate payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate) payAmount, _ := strconv.ParseFloat(payAmountStr, 64) - order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount) + sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount) + if err != nil { + return nil, err + } + if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil { + return nil, err + } + oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel) + if err != nil { + return nil, err + } + if oauthResp != nil { + return oauthResp, nil + } + order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel) if err != nil { return nil, err } - resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan) + resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan, sel) if err != nil { _, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID). SetStatus(OrderStatusFailed). 
@@ -103,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe return plan, nil } -func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) { +func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) { tx, err := s.entClient.Tx(ctx) if err != nil { return nil, fmt.Errorf("begin transaction: %w", err) @@ -120,6 +139,17 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq tm = defaultOrderTimeoutMin } exp := time.Now().Add(time.Duration(tm) * time.Minute) + outTradeNo, err := s.allocateOutTradeNo(ctx, tx) + if err != nil { + return nil, err + } + providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req) + selectedInstanceID := "" + selectedProviderKey := "" + if sel != nil { + selectedInstanceID = strings.TrimSpace(sel.InstanceID) + selectedProviderKey = strings.TrimSpace(sel.ProviderKey) + } b := tx.PaymentOrder.Create(). SetUserID(req.UserID). SetUserEmail(user.Email). @@ -129,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq SetPayAmount(payAmount). SetFeeRate(feeRate). SetRechargeCode(""). - SetOutTradeNo(generateOutTradeNo()). + SetOutTradeNo(outTradeNo). SetPaymentType(req.PaymentType). SetPaymentTradeNo(""). SetOrderType(req.OrderType). 
@@ -140,6 +170,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq if req.SrcURL != "" { b.SetSrcURL(req.SrcURL) } + if selectedInstanceID != "" { + b.SetProviderInstanceID(selectedInstanceID) + } + if selectedProviderKey != "" { + b.SetProviderKey(selectedProviderKey) + } + if providerSnapshot != nil { + b.SetProviderSnapshot(providerSnapshot) + } if plan != nil { b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit)) } @@ -158,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq return order, nil } +func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) { + const maxAttempts = 5 + for attempt := 0; attempt < maxAttempts; attempt++ { + candidate := generateOutTradeNo() + exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx) + if err != nil { + return "", fmt.Errorf("check out_trade_no uniqueness: %w", err) + } + if !exists { + return candidate, nil + } + } + return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts) +} + func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error { if max <= 0 { max = defaultMaxPendingOrders @@ -167,12 +221,71 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return fmt.Errorf("count pending orders: %w", err) } if c >= max { - return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)). + return infraerrors.TooManyRequests("TOO_MANY_PENDING", "too_many_pending"). 
WithMetadata(map[string]string{"max": strconv.Itoa(max)}) } return nil } +func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any { + if sel == nil { + return nil + } + + snapshot := map[string]any{} + snapshot["schema_version"] = 2 + + instanceID := strings.TrimSpace(sel.InstanceID) + if instanceID != "" { + snapshot["provider_instance_id"] = instanceID + } + + providerKey := strings.TrimSpace(sel.ProviderKey) + if providerKey != "" { + snapshot["provider_key"] = providerKey + } + + paymentMode := strings.TrimSpace(sel.PaymentMode) + if paymentMode != "" { + snapshot["payment_mode"] = paymentMode + } + + if providerKey == payment.TypeWxpay { + if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" { + snapshot["merchant_app_id"] = merchantAppID + } + if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" { + snapshot["merchant_id"] = merchantID + } + snapshot["currency"] = "CNY" + } + if providerKey == payment.TypeAlipay { + if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" { + snapshot["merchant_app_id"] = merchantAppID + } + } + if providerKey == payment.TypeEasyPay { + if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" { + snapshot["merchant_id"] = merchantID + } + } + + if len(snapshot) == 1 { + return nil + } + return snapshot +} + +func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string { + if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay { + return "" + } + if strings.TrimSpace(req.OpenID) != "" { + return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config)) + } + return strings.TrimSpace(sel.Config["appId"]) +} + func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil @@ -191,33 +304,127 @@ func (s *PaymentService) checkDailyLimit(ctx 
context.Context, tx *dbent.Tx, user used += o.Amount } if used+amount > limit { - return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used))) + return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily_limit_exceeded"). + WithMetadata(map[string]string{"remaining": fmt.Sprintf("%.2f", math.Max(0, limit-used))}) } return nil } -func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { - // Select an instance across all providers that support the requested payment type. - // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). - sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) +func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) { + selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req) if err != nil { - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) + return nil, err + } + sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) + if err != nil { + return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured"). 
+ WithMetadata(map[string]string{"payment_type": req.PaymentType}) } if sel == nil { - return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") + return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance") + } + return sel, nil +} + +func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) { + if !requestNeedsWeChatJSAPICompatibility(req) { + return ctx, nil + } + if !s.usesOfficialWxpayVisibleMethod(ctx) { + return ctx, nil + } + expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return nil, err + } + return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil +} + +func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool { + if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return false + } + return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" +} + +func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool { + if s == nil || s.configService == nil { + return false + } + inst, err := s.configService.resolveEnabledVisibleMethodInstance(ctx, payment.TypeWxpay) + if err != nil { + return false } + if inst == nil { + return false + } + return inst.ProviderKey == payment.TypeWxpay +} + +func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") + slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) + // If the provider 
returned a structured ApplicationError (e.g. WXPAY_CONFIG_MISSING_KEY), + // pass it through with provider context added to metadata. Otherwise wrap as PAYMENT_PROVIDER_MISCONFIGURED. + if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) { + md := map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID} + for k, v := range appErr.Metadata { + md[k] = v + } + return nil, appErr.WithMetadata(md) + } + return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured"). + WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}) } subject := s.buildPaymentSubject(plan, limitAmount, cfg) outTradeNo := order.OutTradeNo - pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) + canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL) if err != nil { - slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) + return nil, err + } + resumeToken := "" + if resume := s.paymentResume(); resume != nil { + if canonicalReturnURL != "" && resume.isSigningConfigured() { + resumeToken, err = resume.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: order.UserID, + ProviderInstanceID: sel.InstanceID, + ProviderKey: sel.ProviderKey, + PaymentType: req.PaymentType, + CanonicalReturnURL: canonicalReturnURL, + }) + if err != nil { + return nil, fmt.Errorf("create payment resume token: %w", err) + } + } } - _, err = 
s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) + providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken) + if err != nil { + return nil, err + } + providerReq := buildProviderCreatePaymentRequest(CreateOrderRequest{ + PaymentType: req.PaymentType, + OpenID: req.OpenID, + ClientIP: req.ClientIP, + IsMobile: req.IsMobile, + ReturnURL: providerReturnURL, + }, sel, outTradeNo, payAmountStr, subject) + pr, err := prov.CreatePayment(ctx, providerReq) + if err != nil { + slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) + if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) { + return nil, appErr + } + return nil, classifyCreatePaymentError(req, sel.ProviderKey, err) + } + _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID). + SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)). + SetNillablePayURL(psNilIfEmpty(pr.PayURL)). + SetNillableQrCode(psNilIfEmpty(pr.QRCode)). + SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)). + SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)). 
+ Save(ctx) if err != nil { return nil, fmt.Errorf("update order with payment details: %w", err) } @@ -227,8 +434,36 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen "payAmount": order.PayAmount, "paymentType": req.PaymentType, "orderType": req.OrderType, + "paymentSource": NormalizePaymentSource(req.PaymentSource), }) - return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil + resultType := pr.ResultType + if resultType == "" { + resultType = payment.CreatePaymentResultOrderCreated + } + resp := buildCreateOrderResponse(order, req, payAmount, sel, pr, resultType) + resp.ResumeToken = resumeToken + return resp, nil +} + +func buildProviderCreatePaymentRequest(req CreateOrderRequest, sel *payment.InstanceSelection, orderID, amount, subject string) payment.CreatePaymentRequest { + return payment.CreatePaymentRequest{ + OrderID: orderID, + Amount: amount, + PaymentType: req.PaymentType, + Subject: subject, + ReturnURL: req.ReturnURL, + OpenID: strings.TrimSpace(req.OpenID), + ClientIP: req.ClientIP, + IsMobile: req.IsMobile, + InstanceSubMethods: selectedInstanceSupportedTypes(sel), + } +} + +func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string { + if sel == nil { + return "" + } + return sel.SupportedTypes } func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string { @@ -247,6 +482,193 @@ func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limit return "Sub2API " + amountStr + " CNY" } +func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) { + return 
s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil) +} + +func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponseForSelection(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { + if sel != nil && sel.ProviderKey != "" && sel.ProviderKey != payment.TypeWxpay { + return nil, nil + } + if strings.TrimSpace(req.OpenID) != "" || !req.IsWeChatBrowser || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return nil, nil + } + return s.buildWeChatOAuthRequiredResponse(ctx, req, amount, payAmount, feeRate) +} + +func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) { + appID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return nil, err + } + if err := s.paymentResume().ensureSigningKey(); err != nil { + return nil, err + } + + authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base") + if err != nil { + return nil, err + } + + return &CreateOrderResponse{ + Amount: amount, + PayAmount: payAmount, + FeeRate: feeRate, + ResultType: payment.CreatePaymentResultOAuthRequired, + PaymentType: req.PaymentType, + OAuth: &payment.WechatOAuthInfo{ + AuthorizeURL: authorizeURL, + AppID: appID, + Scope: "snsapi_base", + RedirectURL: "/auth/wechat/payment/callback", + }, + }, nil +} + +func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context, req CreateOrderRequest, sel *payment.InstanceSelection) error { + if !requiresWeChatJSAPICompatibleSelection(req, sel) { + return nil + } + expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx) + if err != nil { + return err + } + selectedAppID := provider.ResolveWxpayJSAPIAppID(sel.Config) + if selectedAppID == "" || selectedAppID != expectedAppID { + return 
infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "selected payment instance is not compatible with the current WeChat OAuth app") + } + return nil +} + +func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool { + if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay { + return false + } + return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" +} + +func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) { + if s == nil || s.configService == nil || s.configService.settingRepo == nil { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx) + appID := strings.TrimSpace(cfg.AppIDForMode("mp")) + appSecret := strings.TrimSpace(cfg.AppSecretForMode("mp")) + if err != nil || !cfg.SupportsMode("mp") || appID == "" || appSecret == "" { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + return appID, appSecret, nil +} + +func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error { + if err == nil { + return nil + } + if providerKey == payment.TypeWxpay && + payment.GetBasePaymentType(req.PaymentType) == payment.TypeWxpay && + strings.Contains(err.Error(), "wxpay h5 payments are not authorized for this merchant") { + return infraerrors.ServiceUnavailable( + "WECHAT_H5_NOT_AUTHORIZED", + "wechat h5 payment is not available for this merchant", + ).WithMetadata(map[string]string{ + "action": "open_in_wechat_or_scan_qr", + }) + } + return infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", 
err.Error())) +} + +func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest, payAmount float64, sel *payment.InstanceSelection, pr *payment.CreatePaymentResponse, resultType payment.CreatePaymentResultType) *CreateOrderResponse { + return &CreateOrderResponse{ + OrderID: order.ID, + Amount: order.Amount, + PayAmount: payAmount, + FeeRate: order.FeeRate, + Status: OrderStatusPending, + ResultType: resultType, + PaymentType: req.PaymentType, + OutTradeNo: order.OutTradeNo, + PayURL: pr.PayURL, + QRCode: pr.QRCode, + ClientSecret: pr.ClientSecret, + OAuth: pr.OAuth, + JSAPI: pr.JSAPI, + JSAPIPayload: pr.JSAPI, + ExpiresAt: order.ExpiresAt, + PaymentMode: sel.PaymentMode, + } +} + +func buildWeChatPaymentOAuthStartURL(req CreateOrderRequest, scope string) (string, error) { + u, err := url.Parse("/api/v1/auth/oauth/wechat/payment/start") + if err != nil { + return "", fmt.Errorf("build wechat payment oauth start url: %w", err) + } + q := u.Query() + q.Set("payment_type", strings.TrimSpace(req.PaymentType)) + if req.Amount > 0 { + q.Set("amount", strconv.FormatFloat(req.Amount, 'f', -1, 64)) + } + if orderType := strings.TrimSpace(req.OrderType); orderType != "" { + q.Set("order_type", orderType) + } + if req.PlanID > 0 { + q.Set("plan_id", strconv.FormatInt(req.PlanID, 10)) + } + if scope = strings.TrimSpace(scope); scope != "" { + q.Set("scope", scope) + } + if redirectTo := paymentRedirectPathFromURL(req.SrcURL); redirectTo != "" { + q.Set("redirect", redirectTo) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +func paymentRedirectPathFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "/purchase" + } + if strings.HasPrefix(rawURL, "/") && !strings.HasPrefix(rawURL, "//") { + return normalizePaymentRedirectPath(rawURL) + } + u, err := url.Parse(rawURL) + if err != nil { + return "/purchase" + } + path := strings.TrimSpace(u.EscapedPath()) + if path == "" { + path = 
strings.TrimSpace(u.Path) + } + if path == "" || !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") { + return "/purchase" + } + if strings.TrimSpace(u.RawQuery) != "" { + path += "?" + u.RawQuery + } + return normalizePaymentRedirectPath(path) +} + +func normalizePaymentRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "/purchase" + } + if path == "/payment" { + return "/purchase" + } + if strings.HasPrefix(path, "/payment?") { + return "/purchase" + strings.TrimPrefix(path, "/payment") + } + return path +} + // --- Order Queries --- func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) { diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8c5e4fc0e34e8a6a5cebcb2b025695eefca952bd --- /dev/null +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -0,0 +1,98 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(1). 
+ Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{entClient: client}, + } + + if !svc.usesOfficialWxpayVisibleMethod(ctx) { + t.Fatal("expected official wxpay visible method to be detected from enabled provider instance") + } +} + +func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) { + tests := []struct { + name string + source string + wantOfficial bool + }{ + { + name: "official source selected", + source: VisibleMethodSourceOfficialWechat, + wantOfficial: true, + }, + { + name: "easypay source selected", + source: VisibleMethodSourceEasyPayWechat, + wantOfficial: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + SetSortOrder(2). 
+ Save(ctx) + if err != nil { + t.Fatalf("create easypay wxpay instance: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodWxpaySource: tt.source, + }, + }, + }, + } + + if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial { + t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial) + } + }) + } +} diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index 801471804c7597ea667ca86279b35177a3aec95f..b627ced4ecc0907cc73efc69334ea4b23bbe4e66 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -139,34 +140,123 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s if err != nil { return "" } - // Use OutTradeNo as fallback when PaymentTradeNo is empty - // (e.g. 
EasyPay popup mode where trade_no arrives only via notify callback) - tradeNo := o.PaymentTradeNo - if tradeNo == "" { - tradeNo = o.OutTradeNo + queryRef := paymentOrderQueryReference(o, prov) + if queryRef == "" { + return "" } - resp, err := prov.QueryOrder(ctx, tradeNo) + resp, err := prov.QueryOrder(ctx, queryRef) if err != nil { slog.Warn("query upstream failed", "orderID", o.ID, "error", err) return "" } if resp.Status == payment.ProviderStatusPaid { - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { + if !isValidProviderAmount(resp.Amount) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{ + "expected": o.PayAmount, + "paid": resp.Amount, + "tradeNo": resp.TradeNo, + "queryRef": queryRef, + }) + slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount) + retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef) + if !retryOK { + return "" + } + resp = retriedResp + } + notificationTradeNo := o.PaymentTradeNo + if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) { + if _, updateErr := s.entClient.PaymentOrder.Update(). + Where(paymentorder.IDEQ(o.ID)). + SetPaymentTradeNo(upstreamTradeNo). 
+ Save(ctx); updateErr != nil { + slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr) + } else { + o.PaymentTradeNo = upstreamTradeNo + } + notificationTradeNo = upstreamTradeNo + } + if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil { slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) // Still return already_paid — order was paid, fulfillment can be retried } return checkPaidResultAlreadyPaid } if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, tradeNo) + _ = cp.CancelPayment(ctx, queryRef) } return "" } +func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) { + if prov == nil || strings.TrimSpace(queryRef) == "" { + return nil, false + } + resp, err := prov.QueryOrder(ctx, queryRef) + if err != nil { + slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err) + return nil, false + } + if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) { + return nil, false + } + return resp, true +} + +func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string { + if order == nil { + return "" + } + + providerKey := "" + if prov != nil { + providerKey = strings.TrimSpace(prov.ProviderKey()) + } + if providerKey == "" { + if snapshot := psOrderProviderSnapshot(order); snapshot != nil { + providerKey = strings.TrimSpace(snapshot.ProviderKey) + } + } + if providerKey == "" { + providerKey = strings.TrimSpace(psStringValue(order.ProviderKey)) + } + if providerKey == "" { + providerKey = strings.TrimSpace(order.PaymentType) + } + + switch payment.GetBasePaymentType(providerKey) { + case 
payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay: + return strings.TrimSpace(order.OutTradeNo) + default: + if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" { + return tradeNo + } + return strings.TrimSpace(order.OutTradeNo) + } +} + +func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool { + upstreamTradeNo = strings.TrimSpace(upstreamTradeNo) + if upstreamTradeNo == "" { + return false + } + if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) { + return false + } + if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) { + return false + } + return true +} + // VerifyOrderByOutTradeNo actively queries the upstream provider to check // if a payment was made, and processes it if so. This handles the case where // the provider's notify callback was missed (e.g. EasyPay popup mode). func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) @@ -190,25 +280,42 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo return o, nil } -// VerifyOrderPublic verifies payment status without user authentication. -// Used by the payment result page when the user's session has expired. +// VerifyOrderPublic returns the currently persisted public order state without +// triggering any upstream reconciliation. Signed resume-token recovery is the +// only public recovery path allowed to query upstream state. func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). 
Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) if err != nil { return nil, infraerrors.NotFound("NOT_FOUND", "order not found") } - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == checkPaidResultAlreadyPaid { - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } + return o, nil +} + +func normalizeOrderLookupOutTradeNo(raw string) (string, error) { + outTradeNo := strings.TrimSpace(raw) + if outTradeNo == "" { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required") + } + if len(outTradeNo) > 64 { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") + } + for _, ch := range outTradeNo { + switch { + case ch >= 'a' && ch <= 'z': + case ch >= 'A' && ch <= 'Z': + case ch >= '0' && ch <= '9': + case ch == '_' || ch == '-': + default: + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") } } - return o, nil + return outTradeNo, nil } func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { @@ -236,22 +343,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). 
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) + } + if !paymentOrderAllowsRegistryFallback(o) { + return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID) + } + providerKey := paymentOrderFallbackProviderKey(s.registry, o) + if providerKey == "" { + return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } + +func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool { + if order == nil { + return false + } + if psOrderProviderSnapshot(order) != nil { + return false + } + if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" { + return false + } + if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" { + return false + } + return true +} + +func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string { + if order == nil { + return "" + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" { + return key + } + } + return 
strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType))) +} + +func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { + if inst == nil { + return nil, fmt.Errorf("payment provider instance is missing") + } + + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + if inst.PaymentMode != "" { + cfg["paymentMode"] = inst.PaymentMode + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create provider from instance: %w", err) + } + return prov, nil +} + +func psStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8dfd2e7e01cdd1e72f78205a3fe12f37141fd435 --- /dev/null +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -0,0 +1,575 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type paymentOrderLifecycleQueryProvider struct { + lastQueryTradeNo string + queryCalls int + responses []*payment.QueryOrderResponse + resp *payment.QueryOrderResponse +} + +type paymentOrderLifecycleRedeemRepo struct { + codesByCode map[string]*RedeemCode + useCalls []struct { + id int64 + userID int64 + } +} + +func (p 
*paymentOrderLifecycleQueryProvider) Name() string { + return "payment-order-lifecycle-query-provider" +} + +func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay } + +func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.TypeAlipay} +} + +func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} + +func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + p.lastQueryTradeNo = tradeNo + p.queryCalls++ + if len(p.responses) > 0 { + resp := p.responses[0] + if len(p.responses) > 1 { + p.responses = p.responses[1:] + } + return resp, nil + } + return p.resp, nil +} + +func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} + +func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) { + for _, code := range r.codesByCode { + if code.ID != id { + continue + } + cloned := *code + return &cloned, nil + } + return nil, ErrRedeemCodeNotFound +} + +func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) { + redeemCode, ok := r.codesByCode[code] + if !ok { + return nil, ErrRedeemCodeNotFound + } + cloned := *redeemCode + return &cloned, nil +} + +func (r 
*paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error { + for code, redeemCode := range r.codesByCode { + if redeemCode.ID != id { + continue + } + now := time.Now().UTC() + redeemCode.Status = StatusUsed + redeemCode.UsedBy = &userID + redeemCode.UsedAt = &now + r.codesByCode[code] = redeemCode + r.useCalls = append(r.useCalls, struct { + id int64 + userID int64 + }{id: id, userID: userID}) + return nil + } + return ErrRedeemCodeNotFound +} + +func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected call") +} + +func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected call") +} + +func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). 
+ SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO"). + SetOutTradeNo("sub2_checkpaid_trade_no_missing"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-123", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, OrderStatusCompleted, got.Status) + require.Equal(t, "upstream-trade-123", got.PaymentTradeNo) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, 
OrderStatusCompleted, reloaded.Status) + require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo) + + require.Equal(t, 88.0, userRepo.getByIDUser.Balance) + require.Len(t, redeemRepo.useCalls, 1) + require.Equal(t, int64(1), redeemRepo.useCalls[0].id) + require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) +} + +func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-retry@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-retry-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-UPSTREAM-RETRY"). + SetOutTradeNo("sub2_checkpaid_retry_zero_amount"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + responses: []*payment.QueryOrderResponse{ + { + TradeNo: "upstream-trade-zero", + Status: payment.ProviderStatusPaid, + Amount: 0, + }, + { + TradeNo: "upstream-trade-retry", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, 2, provider.queryCalls) + require.Equal(t, OrderStatusCompleted, got.Status) + require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo) +} + +func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-zero-amount@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-zero-amount-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). 
+ SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-ZERO-AMOUNT"). + SetOutTradeNo("sub2_checkpaid_zero_amount"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-zero", + Status: payment.ProviderStatusPaid, + Amount: 0, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, OrderStatusPending, got.Status) + require.Empty(t, got.PaymentTradeNo) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusPending, reloaded.Status) + require.Empty(t, reloaded.PaymentTradeNo) + + require.Equal(t, 0.0, userRepo.getByIDUser.Balance) + require.Empty(t, redeemRepo.useCalls) +} + +func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) { + ctx := 
context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-existing-trade@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-existing-trade-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO"). + SetOutTradeNo("sub2_checkpaid_use_out_trade_no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("upstream-trade-existing"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: "upstream-trade-existing", + Status: payment.ProviderStatusPaid, + Amount: 88, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + got, err := 
svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo) +} + +func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) { + t.Parallel() + + require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + })) + + instanceID := "12" + require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderInstanceID: &instanceID, + })) + + require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ + PaymentType: payment.TypeAlipay, + ProviderSnapshot: map[string]any{ + "schema_version": 2, + "provider_instance_id": "12", + }, + })) +} + +func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) { + t.Parallel() + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + OutTradeNo: "sub2_out_trade_no", + PaymentTradeNo: "wx-transaction-id", + } + + require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{})) + require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{ + key: payment.TypeWxpay, + })) +} + +func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go new 
file mode 100644 index 0000000000000000000000000000000000000000..bb60f9e25b884bb93f1c611a34c4aa9375ce1514 --- /dev/null +++ b/backend/internal/service/payment_order_provider_snapshot.go @@ -0,0 +1,205 @@ +package service + +import ( + "context" + "fmt" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +type paymentOrderProviderSnapshot struct { + SchemaVersion int + ProviderInstanceID string + ProviderKey string + PaymentMode string + MerchantAppID string + MerchantID string + Currency string +} + +func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot { + if order == nil || len(order.ProviderSnapshot) == 0 { + return nil + } + + snapshot := &paymentOrderProviderSnapshot{ + SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]), + ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]), + ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]), + PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]), + MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]), + MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]), + Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]), + } + if snapshot.SchemaVersion == 0 && + snapshot.ProviderInstanceID == "" && + snapshot.ProviderKey == "" && + snapshot.PaymentMode == "" && + snapshot.MerchantAppID == "" && + snapshot.MerchantID == "" && + snapshot.Currency == "" { + return nil + } + return snapshot +} + +func psSnapshotStringValue(value any) string { + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + default: + return "" + } +} + +func psSnapshotIntValue(value any) int { + switch typed := value.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float32: + return int(typed) + case 
float64: + return int(typed) + case string: + n, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return n + } + } + return 0 +} + +func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil || order == nil || snapshot == nil { + return nil, nil + } + + snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID) + columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) + if snapshotInstanceID == "" { + snapshotInstanceID = columnInstanceID + } + if snapshotInstanceID == "" { + return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID) + } + if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) { + return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID) + } + + instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64) + if err != nil { + return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID) + } + + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID) + } + return nil, err + } + + if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) { + return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey) + } + + return inst, nil +} + +func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string { + if order == nil { + return 
strings.TrimSpace(instanceProviderKey) + } + + orderProviderKey := psStringValue(order.ProviderKey) + if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" { + orderProviderKey = snapshot.ProviderKey + } + + return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey) +} + +func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error { + if order == nil || len(metadata) == 0 { + return nil + } + + snapshot := psOrderProviderSnapshot(order) + if snapshot == nil { + return nil + } + + switch strings.TrimSpace(providerKey) { + case payment.TypeWxpay: + if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" { + actual := strings.TrimSpace(metadata["appid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing appid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" { + actual := strings.TrimSpace(metadata["mchid"]) + if actual == "" { + return fmt.Errorf("wxpay notification missing mchid") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual) + } + } + if expected := strings.TrimSpace(snapshot.Currency); expected != "" { + actual := strings.ToUpper(strings.TrimSpace(metadata["currency"])) + if actual == "" { + return fmt.Errorf("wxpay notification missing currency") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual) + } + } + if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") { + return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual) + } + case payment.TypeAlipay: + if expected := 
strings.TrimSpace(snapshot.MerchantAppID); expected != "" { + actual := strings.TrimSpace(metadata["app_id"]) + if actual == "" { + return fmt.Errorf("alipay app_id missing") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual) + } + } + case payment.TypeEasyPay: + if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" { + actual := strings.TrimSpace(metadata["pid"]) + if actual == "" { + return fmt.Errorf("easypay pid missing") + } + if !strings.EqualFold(expected, actual) { + return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual) + } + } + } + + return nil +} + +func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string { + if prov == nil { + return nil + } + reporter, ok := prov.(payment.MerchantIdentityProvider) + if !ok { + return nil + } + return reporter.MerchantIdentityMetadata() +} diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go new file mode 100644 index 0000000000000000000000000000000000000000..efa013b52dd6e0e1d9f2013c6a320c53d28d8887 --- /dev/null +++ b/backend/internal/service/payment_order_provider_snapshot_test.go @@ -0,0 +1,172 @@ +//go:build unit + +package service + +import ( + "context" + "strconv" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) { + t.Parallel() + + sel := &payment.InstanceSelection{ + InstanceID: "12", + ProviderKey: payment.TypeWxpay, + SupportedTypes: "wxpay,wxpay_direct", + PaymentMode: "popup", + Config: map[string]string{ + "privateKey": "secret", + "apiV3Key": "secret-v3", + "appId": "wx-app-id", + }, + } + + snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{}) + require.Equal(t, map[string]any{ + "schema_version": 2, + 
"provider_instance_id": "12", + "provider_key": payment.TypeWxpay, + "payment_mode": "popup", + "merchant_app_id": "wx-app-id", + "currency": "CNY", + }, snapshot) + require.NotContains(t, snapshot, "config") + require.NotContains(t, snapshot, "privateKey") + require.NotContains(t, snapshot, "apiV3Key") + require.NotContains(t, snapshot, "supported_types") + require.NotContains(t, snapshot, "instance_name") + require.NotContains(t, snapshot, "merchant_id") +} + +func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("snapshot@example.com"). + SetPasswordHash("hash"). + SetUsername("snapshot-user"). + Save(ctx) + require.NoError(t, err) + + instance, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Primary Alipay"). + SetConfig(`{"secretKey":"do-not-copy"}`). + SetSupportedTypes("alipay,alipay_direct"). + SetPaymentMode("redirect"). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{entClient: client} + order, err := svc.createOrderInTx( + ctx, + CreateOrderRequest{ + UserID: user.ID, + PaymentType: payment.TypeAlipay, + OrderType: payment.OrderTypeBalance, + ClientIP: "127.0.0.1", + SrcHost: "app.example.com", + }, + &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + }, + nil, + &PaymentConfig{ + MaxPendingOrders: 3, + OrderTimeoutMin: 30, + }, + 88, + 88, + 0, + 88, + &payment.InstanceSelection{ + InstanceID: strconv.FormatInt(instance.ID, 10), + ProviderKey: payment.TypeAlipay, + SupportedTypes: "alipay,alipay_direct", + PaymentMode: "redirect", + Config: map[string]string{ + "secretKey": "do-not-copy", + }, + }, + ) + require.NoError(t, err) + require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID)) + require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey)) + require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"]) + require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"]) + require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"]) + require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"]) + require.NotContains(t, order.ProviderSnapshot, "config") + require.NotContains(t, order.ProviderSnapshot, "secretKey") + require.NotContains(t, order.ProviderSnapshot, "supported_types") + require.NotContains(t, order.ProviderSnapshot, "instance_name") +} + +func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "88", + ProviderKey: payment.TypeWxpay, + Config: map[string]string{ + "appId": "wx-open-app", + "mpAppId": "wx-mp-app", + "mchId": "mch-88", + }, + PaymentMode: "jsapi", + }, CreateOrderRequest{OpenID: "openid-123"}) + + require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"]) + 
require.Equal(t, "mch-88", snapshot["merchant_id"]) + require.Equal(t, "CNY", snapshot["currency"]) +} + +func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "21", + ProviderKey: payment.TypeAlipay, + Config: map[string]string{ + "appId": "alipay-app-21", + "privateKey": "secret", + }, + PaymentMode: "redirect", + }, CreateOrderRequest{}) + + require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"]) + require.NotContains(t, snapshot, "privateKey") +} + +func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) { + t.Parallel() + + snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{ + InstanceID: "66", + ProviderKey: payment.TypeEasyPay, + Config: map[string]string{ + "pid": "easypay-merchant-66", + "pkey": "secret", + }, + PaymentMode: "popup", + }, CreateOrderRequest{PaymentType: payment.TypeAlipay}) + + require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"]) + require.NotContains(t, snapshot, "pkey") +} + +func valueOrEmpty(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2d7412e0612accbad939ebb4ae79343bb342eadb --- /dev/null +++ b/backend/internal/service/payment_order_result_test.go @@ -0,0 +1,276 @@ +package service + +import ( + "context" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) { + t.Parallel() + + expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC) + resp := buildCreateOrderResponse( + &dbent.PaymentOrder{ + ID: 42, 
+ Amount: 12.34, + FeeRate: 0.03, + ExpiresAt: expiresAt, + OutTradeNo: "sub2_42", + }, + CreateOrderRequest{PaymentType: payment.TypeWxpay}, + 12.71, + &payment.InstanceSelection{PaymentMode: "qrcode"}, + &payment.CreatePaymentResponse{ + TradeNo: "sub2_42", + QRCode: "weixin://wxpay/bizpayurl?pr=test", + }, + payment.CreatePaymentResultOrderCreated, + ) + + if resp.ResultType != payment.CreatePaymentResultOrderCreated { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated) + } + if resp.OutTradeNo != "sub2_42" { + t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42") + } + if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" { + t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test") + } + if resp.JSAPI != nil || resp.JSAPIPayload != nil { + t.Fatal("order_created response should not include jsapi payload") + } + if !resp.ExpiresAt.Equal(expiresAt) { + t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt) + } +} + +func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { + t.Parallel() + + jsapiPayload := &payment.WechatJSAPIPayload{ + AppID: "wx123", + TimeStamp: "1712345678", + NonceStr: "nonce-123", + Package: "prepay_id=wx123", + SignType: "RSA", + PaySign: "signed-payload", + } + resp := buildCreateOrderResponse( + &dbent.PaymentOrder{ + ID: 88, + Amount: 66.88, + FeeRate: 0.01, + ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC), + OutTradeNo: "sub2_88", + }, + CreateOrderRequest{PaymentType: payment.TypeWxpay}, + 67.55, + &payment.InstanceSelection{PaymentMode: "popup"}, + &payment.CreatePaymentResponse{ + TradeNo: "sub2_88", + ResultType: payment.CreatePaymentResultJSAPIReady, + JSAPI: jsapiPayload, + }, + payment.CreatePaymentResultJSAPIReady, + ) + + if resp.ResultType != payment.CreatePaymentResultJSAPIReady { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady) + } + if resp.JSAPI == nil || 
resp.JSAPIPayload == nil { + t.Fatal("expected jsapi payload aliases to be populated") + } + if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload { + t.Fatal("expected jsapi aliases to preserve the original pointer") + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatal("expected oauth_required response, got nil") + } + if resp.ResultType != payment.CreatePaymentResultOAuthRequired { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired) + } + if resp.OAuth == nil { + t.Fatal("expected oauth payload, got nil") + } + if resp.OAuth.AppID != "wx123456" { + t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456") + } + if resp.OAuth.Scope != "snsapi_base" { + t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base") + } + if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" { + t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback") + } + if resp.OAuth.AuthorizeURL != 
"/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" { + t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL) + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) { + t.Parallel() + + svc := newWeChatPaymentOAuthTestService(nil) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } + if err == nil { + t.Fatal("expected error, got nil") + } + + appErr := infraerrors.FromError(err) + if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" { + t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED") + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) { + t.Parallel() + + svc := &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }}, + // Intentionally missing payment resume signing key. 
+ encryptionKey: nil, + }, + } + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } + if err == nil { + t.Fatal("expected error, got nil") + } + + appErr := infraerrors.FromError(err) + if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" { + t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED") + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) { + svc := &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }}, + // Legacy stable signing key remains available for no-config upgrade compatibility. 
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if resp == nil { + t.Fatal("expected oauth-required response, got nil") + } + if resp.ResultType != payment.CreatePaymentResultOAuthRequired { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired) + } + if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" { + t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth) + } +} + +func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03, &payment.InstanceSelection{ + ProviderKey: payment.TypeEasyPay, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("expected nil response, got %+v", resp) + } +} + +func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService { + return &PaymentService{ + configService: 
&PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: values}, + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } +} diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index c5bda763cd96324a5fda22ba91b6da14c93964ff..7521878c7dd5520f832b3ac9e3c719b7623b7033 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -19,18 +20,133 @@ import ( // --- Refund Flow --- // getOrderProviderInstance looks up the provider instance that processed this order. -// Returns nil, nil for legacy orders without provider_instance_id. +// For legacy orders without provider_instance_id, it resolves only when the +// historical instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { - if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + if s == nil || s.entClient == nil || o == nil { return nil, nil } - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + + if snapshot := psOrderProviderSnapshot(o); snapshot != nil { + return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot) + } + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) if err != nil { return nil, nil } return s.entClient.PaymentProviderInstance.Get(ctx, instID) } +// getRefundOrderProviderInstance resolves the provider instance for refund paths. 
+// Refunds must be pinned to an explicit historical binding, so legacy +// "best-effort" provider guessing is intentionally not allowed here. +func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil || o == nil { + return nil, nil + } + + if snapshot := psOrderProviderSnapshot(o); snapshot != nil { + return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot) + } + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return nil, nil + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr) + } + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr) + } + return nil, err + } + return inst, nil +} + +func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey != "" { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.ProviderKeyEQ(providerKey)). + All(ctx) + if err != nil { + return nil, err + } + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil + } + + if paymentType == "" { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). 
+ All(ctx) + if err != nil { + return nil, err + } + + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance { + if len(instances) == 0 { + return nil + } + if strings.TrimSpace(orderPaymentType) == "" { + return instances + } + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(orderPaymentType, inst) { + matched = append(matched, inst) + } + } + return matched +} + +func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { + if inst == nil { + return false + } + + baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType)) + instanceProviderKey := strings.TrimSpace(inst.ProviderKey) + if baseType == "" { + return false + } + + if baseType == payment.TypeStripe { + return instanceProviderKey == payment.TypeStripe + } + if instanceProviderKey == payment.TypeStripe { + return false + } + if instanceProviderKey == baseType { + return true + } + return payment.InstanceSupportsType(inst.SupportedTypes, baseType) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { @@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } // Check provider instance allows user refund - inst, err := s.getOrderProviderInstance(ctx, o) + inst, err := s.getRefundOrderProviderInstance(ctx, o) if err != nil || inst == nil { return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order") } @@ -92,7 +208,7 @@ func (s *PaymentService) 
PrepareRefund(ctx context.Context, oid int64, amt float
 		return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
 	}
 	// Check provider instance allows admin refund
-	inst, instErr := s.getOrderProviderInstance(ctx, o)
+	inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
 	if instErr != nil {
 		slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
 		return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
@@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
 	if err != nil {
 		return fmt.Errorf("get refund provider: %w", err)
 	}
+	if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
+		s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
+			"detail": err.Error(),
+		})
+		return err
+	}
 	_, err = prov.Refund(ctx, payment.RefundRequest{
 		TradeNo: p.Order.PaymentTradeNo,
 		OrderID: p.Order.OutTradeNo,
@@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
 // getRefundProvider creates a provider using the order's original instance config.
 // Resolves the order's pinned refund provider instance explicitly; refunds never
 // fall back to legacy best-effort provider guessing.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - return s.getOrderProvider(ctx, o) + inst, err := s.getRefundOrderProviderInstance(ctx, o) + if err != nil { + return nil, err + } + if inst == nil { + return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID) + } + return s.createProviderFromInstance(ctx, inst) } func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) { diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ca5b62cb28d94551d14779d46c860e808b2e68a4 --- /dev/null +++ b/backend/internal/service/payment_refund_test.go @@ -0,0 +1,186 @@ +//go:build unit + +package service + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-legacy@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-legacy-user"). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-instance"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetAllowUserRefund(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("REFUND-LEGACY-ORDER"). + SetOutTradeNo("sub2_refund_legacy_order"). 
+ SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-legacy-refund"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + } + + _, err = svc.validateRefundRequest(ctx, order.ID, user.ID) + require.Error(t, err) + require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err)) +} + +func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-legacy-admin@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-legacy-admin-user"). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-admin-instance"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetAllowUserRefund(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(188). + SetPayAmount(188). + SetFeeRate(0). + SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER"). + SetOutTradeNo("sub2_refund_legacy_admin_order"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-legacy-admin-refund"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + } + + plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false) + require.Nil(t, plan) + require.Nil(t, result) + require.Error(t, err) + require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err)) +} + +func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + user, err := client.User.Create(). + SetEmail("refund-snapshot-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("refund-snapshot-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-refund-mismatch-instance"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{ + "appId": "runtime-alipay-app", + "privateKey": "runtime-private-key", + })). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetRefundEnabled(true). + Save(ctx) + require.NoError(t, err) + + instID := strconv.FormatInt(inst.ID, 10) + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER"). + SetOutTradeNo("sub2_refund_snapshot_mismatch_order"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-refund-snapshot-mismatch"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusCompleted). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instID). + SetProviderKey(payment.TypeAlipay). + SetProviderSnapshot(map[string]any{ + "schema_version": 2, + "provider_instance_id": instID, + "provider_key": payment.TypeAlipay, + "merchant_app_id": "expected-alipay-app", + }). 
+ Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + err = svc.gwRefund(ctx, &RefundPlan{ + OrderID: order.ID, + Order: order, + RefundAmount: order.Amount, + GatewayAmount: order.Amount, + Reason: "snapshot mismatch", + }) + require.ErrorContains(t, err, "alipay app_id mismatch") +} diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff061e8cac2df3e503b639608fab195209ef710 --- /dev/null +++ b/backend/internal/service/payment_resume_lookup.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { + claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token)) + if err != nil { + return nil, err + } + + order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, infraerrors.NotFound("NOT_FOUND", "order not found") + } + return nil, fmt.Errorf("get order by resume token: %w", err) + } + if claims.UserID > 0 && order.UserID != claims.UserID { + return nil, invalidResumeTokenMatchError() + } + snapshot := psOrderProviderSnapshot(order) + orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) + orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey)) + if snapshot != nil { + if snapshot.ProviderInstanceID != "" { + orderProviderInstanceID = snapshot.ProviderInstanceID + } + if snapshot.ProviderKey != "" { + orderProviderKey = snapshot.ProviderKey + } + } + if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { + return nil, 
invalidResumeTokenMatchError() + } + if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) { + return nil, invalidResumeTokenMatchError() + } + if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) { + return nil, invalidResumeTokenMatchError() + } + if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { + result := s.checkPaid(ctx, order) + if result == checkPaidResultAlreadyPaid { + order, err = s.entClient.PaymentOrder.Get(ctx, order.ID) + if err != nil { + return nil, fmt.Errorf("reload order by resume token: %w", err) + } + } + } + + return order, nil +} + +func invalidResumeTokenMatchError() error { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order") +} + +func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { + return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) +} diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a7b5b7376c6232903143f12bb47e6961baf47058 --- /dev/null +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -0,0 +1,315 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type paymentResumeLookupProvider struct { + queryCount int +} + +func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" } + +func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay } + +func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType { + return []payment.PaymentType{payment.TypeAlipay} +} + +func 
(p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} + +func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + p.queryCount++ + return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil +} + +func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} + +func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := "12" + providerKey := payment.TypeEasyPay + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-ORDER"). + SetOutTradeNo("sub2_resume_lookup"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-1"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). + SetProviderKey(providerKey). 
+ Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: instanceID, + ProviderKey: providerKey, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-MISMATCH"). + SetOutTradeNo("sub2_resume_lookup_mismatch"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-2"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("12"). + SetProviderKey(payment.TypeEasyPay). 
+ Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "99", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + _, err = svc.GetPublicOrderByResumeToken(ctx, token) + require.Error(t, err) + require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err)) +} + +func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-snapshot-authority@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-snapshot-authority-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY"). + SetOutTradeNo("sub2_resume_snapshot_authority"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-snapshot-authority"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("legacy-column-instance"). + SetProviderKey(payment.TypeAlipay). + SetProviderSnapshot(map[string]any{ + "schema_version": 2, + "provider_instance_id": "snapshot-instance", + "provider_key": payment.TypeEasyPay, + }). 
+ Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "snapshot-instance", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-refresh@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-refresh-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-PENDING"). + SetOutTradeNo("sub2_resume_lookup_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-pending"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentResumeLookupProvider{} + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + resumeService: resumeSvc, + providersLoaded: true, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) + require.Equal(t, 1, provider.queryCount) +} + +func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("public-verify@example.com"). + SetPasswordHash("hash"). + SetUsername("public-verify-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("PUBLIC-VERIFY"). + SetOutTradeNo("sub2_public_verify_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-verify"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). 
+ Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentResumeLookupProvider{} + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) + require.Equal(t, 0, provider.queryCount) +} + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + svc := &PaymentService{ + entClient: newPaymentConfigServiceTestClient(t), + } + + _, err := svc.VerifyOrderPublic(context.Background(), " ") + require.Error(t, err) + require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go new file mode 100644 index 0000000000000000000000000000000000000000..9ae62fde33b84e0c68545133fa840149c4ee4b49 --- /dev/null +++ b/backend/internal/service/payment_resume_service.go @@ -0,0 +1,476 @@ +package service + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const paymentResultReturnPath = "/payment/result" + +const ( + PaymentSourceHostedRedirect = "hosted_redirect" + PaymentSourceWechatInAppResume = "wechat_in_app_resume" + + SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" + SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" + SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" + SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled" + + VisibleMethodSourceOfficialAlipay = "official_alipay" + VisibleMethodSourceEasyPayAlipay = "easypay_alipay" + VisibleMethodSourceOfficialWechat 
= "official_wxpay" + VisibleMethodSourceEasyPayWechat = "easypay_wxpay" + + wechatPaymentResumeTokenType = "wechat_payment_resume" + + paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED" + paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key" + + paymentResumeTokenTTL = 24 * time.Hour + wechatPaymentResumeTokenTTL = 15 * time.Minute +) + +type ResumeTokenClaims struct { + OrderID int64 `json:"oid"` + UserID int64 `json:"uid,omitempty"` + ProviderInstanceID string `json:"pi,omitempty"` + ProviderKey string `json:"pk,omitempty"` + PaymentType string `json:"pt,omitempty"` + CanonicalReturnURL string `json:"ru,omitempty"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp,omitempty"` +} + +type WeChatPaymentResumeClaims struct { + TokenType string `json:"tk,omitempty"` + OpenID string `json:"openid"` + PaymentType string `json:"pt,omitempty"` + Amount string `json:"amt,omitempty"` + OrderType string `json:"ot,omitempty"` + PlanID int64 `json:"pid,omitempty"` + RedirectTo string `json:"rd,omitempty"` + Scope string `json:"scp,omitempty"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp,omitempty"` +} + +type PaymentResumeService struct { + signingKey []byte + verifyKeys [][]byte +} + +type visibleMethodLoadBalancer struct { + inner payment.LoadBalancer + configService *PaymentConfigService +} + +func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService { + svc := &PaymentResumeService{} + if len(signingKey) > 0 { + svc.signingKey = append([]byte(nil), signingKey...) + svc.verifyKeys = append(svc.verifyKeys, svc.signingKey) + } + for _, fallback := range verifyFallbacks { + if len(fallback) == 0 { + continue + } + cloned := append([]byte(nil), fallback...) 
+ duplicate := false + for _, existing := range svc.verifyKeys { + if bytes.Equal(existing, cloned) { + duplicate = true + break + } + } + if !duplicate { + svc.verifyKeys = append(svc.verifyKeys, cloned) + } + } + return svc +} + +func (s *PaymentResumeService) isSigningConfigured() bool { + return s != nil && len(s.signingKey) > 0 +} + +func (s *PaymentResumeService) ensureSigningKey() error { + if s.isSigningConfigured() { + return nil + } + return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage) +} + +func NormalizeVisibleMethod(method string) string { + return payment.GetBasePaymentType(strings.TrimSpace(method)) +} + +func NormalizeVisibleMethods(methods []string) []string { + if len(methods) == 0 { + return nil + } + seen := make(map[string]struct{}, len(methods)) + out := make([]string, 0, len(methods)) + for _, method := range methods { + normalized := NormalizeVisibleMethod(method) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out +} + +func NormalizePaymentSource(source string) string { + switch strings.TrimSpace(strings.ToLower(source)) { + case "", PaymentSourceHostedRedirect: + return PaymentSourceHostedRedirect + case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume: + return PaymentSourceWechatInAppResume + default: + return strings.TrimSpace(strings.ToLower(source)) + } +} + +func NormalizeVisibleMethodSource(method, source string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official": + return VisibleMethodSourceOfficialAlipay + case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayAlipay + } + case payment.TypeWxpay: + switch 
strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official": + return VisibleMethodSourceOfficialWechat + case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayWechat + } + } + return "" +} + +func VisibleMethodProviderKeyForSource(method, source string) (string, bool) { + switch NormalizeVisibleMethodSource(method, source) { + case VisibleMethodSourceOfficialAlipay: + return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceEasyPayAlipay: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceOfficialWechat: + return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay + case VisibleMethodSourceEasyPayWechat: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay + default: + return "", false + } +} + +func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer { + if inner == nil || configService == nil || configService.entClient == nil { + return inner + } + return &visibleMethodLoadBalancer{inner: inner, configService: configService} +} + +func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { + return lb.inner.GetInstanceConfig(ctx, instanceID) +} + +func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) { + visibleMethod := NormalizeVisibleMethod(paymentType) + if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) { + return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount) + } + + inst, err := 
lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod) + if err != nil { + return nil, err + } + if inst == nil { + return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod) + } + return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount) +} + +func visibleMethodEnabledSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipayEnabled + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpayEnabled + default: + return "" + } +} + +func visibleMethodSourceSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipaySource + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpaySource + default: + return "" + } +} + +func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + parsed, err := url.Parse(raw) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https") + } + parsed.Fragment = "" + if parsed.Path == "" { + parsed.Path = "/" + } + if parsed.Path != paymentResultReturnPath { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page") + } + if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin") + } + return parsed.String(), nil +} + +func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool { + 
if sameOriginHost(returnURLHost, requestHost) { + return true + } + + refererURL = strings.TrimSpace(refererURL) + if refererURL == "" { + return false + } + parsedReferer, err := url.Parse(refererURL) + if err != nil || parsedReferer.Host == "" { + return false + } + return sameOriginHost(returnURLHost, parsedReferer.Host) +} + +func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) { + canonical := strings.TrimSpace(base) + if canonical == "" { + return "", nil + } + + parsed, err := url.Parse(canonical) + if err != nil { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL") + } + if !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL") + } + parsed.Fragment = "" + + query := parsed.Query() + if orderID > 0 { + query.Set("order_id", strconv.FormatInt(orderID, 10)) + } + if strings.TrimSpace(outTradeNo) != "" { + query.Set("out_trade_no", strings.TrimSpace(outTradeNo)) + } + if strings.TrimSpace(resumeToken) != "" { + query.Set("resume_token", strings.TrimSpace(resumeToken)) + } + query.Set("status", "success") + parsed.RawQuery = query.Encode() + + return parsed.String(), nil +} + +func sameOriginHost(returnURLHost string, requestHost string) bool { + returnHost := strings.TrimSpace(returnURLHost) + reqHost := strings.TrimSpace(requestHost) + if returnHost == "" || reqHost == "" { + return false + } + if strings.EqualFold(returnHost, reqHost) { + return true + } + + returnName, returnPort := splitHostPortDefault(returnHost) + reqName, reqPort := splitHostPortDefault(reqHost) + if returnName == "" || reqName == "" { + return false + } + return strings.EqualFold(returnName, reqName) && returnPort == reqPort +} + +func splitHostPortDefault(raw string) (string, string) { + if host, port, err := net.SplitHostPort(raw); err == nil { + return host, port + } + return raw, "" +} + +func (s 
*PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } + if claims.OrderID <= 0 { + return "", fmt.Errorf("resume token requires order id") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + if claims.ExpiresAt == 0 { + claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix() + } + return s.createSignedToken(claims) +} + +func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } + var claims ResumeTokenClaims + if err := s.parseSignedToken(token, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") + } + if claims.OrderID <= 0 { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id") + } + if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil { + return nil, err + } + return &claims, nil +} + +func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) { + if err := s.ensureSigningKey(); err != nil { + return "", err + } + claims.OpenID = strings.TrimSpace(claims.OpenID) + if claims.OpenID == "" { + return "", fmt.Errorf("wechat payment resume token requires openid") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + if claims.ExpiresAt == 0 { + claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix() + } + if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" { + claims.PaymentType = normalized + } + if claims.PaymentType == "" { + claims.PaymentType = payment.TypeWxpay + } + if claims.OrderType == "" { + claims.OrderType = payment.OrderTypeBalance + } + claims.TokenType = wechatPaymentResumeTokenType + return s.createSignedToken(claims) +} + +func (s 
*PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { + if err := s.ensureSigningKey(); err != nil { + return nil, err + } + var claims WeChatPaymentResumeClaims + if err := s.parseSignedToken(token, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid") + } + if claims.TokenType != wechatPaymentResumeTokenType { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch") + } + claims.OpenID = strings.TrimSpace(claims.OpenID) + if claims.OpenID == "" { + return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid") + } + if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil { + return nil, err + } + if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" { + claims.PaymentType = normalized + } + if claims.PaymentType == "" { + claims.PaymentType = payment.TypeWxpay + } + if claims.OrderType == "" { + claims.OrderType = payment.OrderTypeBalance + } + return &claims, nil +} + +func (s *PaymentResumeService) createSignedToken(claims any) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal resume claims: %w", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + return encodedPayload + "." 
+ s.sign(encodedPayload), nil +} + +func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { + parts := strings.Split(token, ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") + } + if !s.verifySignature(parts[0], parts[1]) { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed") + } + return json.Unmarshal(payload, dest) +} + +func (s *PaymentResumeService) verifySignature(payload string, signature string) bool { + if s == nil { + return false + } + for _, key := range s.verifyKeys { + if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) { + return true + } + } + return false +} + +func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { + if expiresAt <= 0 { + return nil + } + if time.Now().Unix() > expiresAt { + return infraerrors.BadRequest(code, message) + } + return nil +} + +func (s *PaymentResumeService) sign(payload string) string { + return signPaymentResumePayload(payload, s.signingKey) +} + +func signPaymentResumePayload(payload string, key []byte) string { + mac := hmac.New(sha256.New, key) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7e0adc2de84314d5fab477a3145da7ddc3742a90 --- /dev/null +++ b/backend/internal/service/payment_resume_service_test.go @@ -0,0 +1,808 @@ +//go:build unit + +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/url" + "strconv" + 
"testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func TestNormalizeVisibleMethods(t *testing.T) { + t.Parallel() + + got := NormalizeVisibleMethods([]string{ + "alipay_direct", + "alipay", + " wxpay_direct ", + "wxpay", + "stripe", + }) + + want := []string{"alipay", "wxpay", "stripe"} + if len(got) != len(want) { + t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestNormalizePaymentSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expect string + }{ + {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect}, + {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume}, + {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizePaymentSource(tt.input); got != tt.expect { + t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestCanonicalizeReturnURL(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "") + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://example.com/payment/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2") + } +} + +func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil { + t.Fatal("CanonicalizeReturnURL should reject relative URLs") + } +} + +func 
TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil { + t.Fatal("CanonicalizeReturnURL should reject external hosts") + } +} + +func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL( + "https://app.example.com/payment/result?from=checkout", + "api.example.com", + "https://app.example.com/purchase", + ) + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://app.example.com/payment/result?from=checkout" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout") + } +} + +func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil { + t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths") + } +} + +func TestBuildPaymentReturnURL(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + + parsed, err := url.Parse(got) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + if parsed.Fragment != "" { + t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment) + } + query := parsed.Query() + if query.Get("from") != "checkout" { + t.Fatalf("expected original query to be preserved, got %q", query.Get("from")) + } + if query.Get("order_id") != strconv.FormatInt(42, 10) { + t.Fatalf("order_id = %q", query.Get("order_id")) + } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q", query.Get("out_trade_no")) + } + if query.Get("resume_token") != "resume-token" { + 
t.Fatalf("resume_token = %q", query.Get("resume_token")) + } + if query.Get("status") != "success" { + t.Fatalf("status = %q", query.Get("status")) + } +} + +func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + + parsed, err := url.Parse(got) + if err != nil { + t.Fatalf("url.Parse returned error: %v", err) + } + query := parsed.Query() + if query.Get("order_id") != "42" { + t.Fatalf("order_id = %q", query.Get("order_id")) + } + if query.Get("out_trade_no") != "sub2_42" { + t.Fatalf("out_trade_no = %q", query.Get("out_trade_no")) + } + if query.Get("resume_token") != "" { + t.Fatalf("resume_token = %q, want empty", query.Get("resume_token")) + } +} + +func TestBuildPaymentReturnURLEmptyBase(t *testing.T) { + t.Parallel() + + got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token") + if err != nil { + t.Fatalf("buildPaymentReturnURL returned error: %v", err) + } + if got != "" { + t.Fatalf("buildPaymentReturnURL = %q, want empty string", got) + } +} + +func TestPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + ProviderInstanceID: "19", + ProviderKey: "easypay", + PaymentType: "wxpay", + CanonicalReturnURL: "https://example.com/payment/result", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + claims, err := svc.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken returned error: %v", err) + } + if claims.OrderID != 42 || claims.UserID != 7 { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" { + 
t.Fatalf("claims provider snapshot mismatch: %+v", claims) + } + if claims.CanonicalReturnURL != "https://example.com/payment/result" { + t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL) + } +} + +func TestCreateTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42}) + if err == nil { + t.Fatal("CreateToken should reject missing signing key") + } +} + +func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7}) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseToken(token) + if err == nil { + t.Fatal("ParseToken should reject tokens when signing key is missing") + } +} + +func TestParseTokenRejectsExpiredToken(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + IssuedAt: time.Now().Add(-25 * time.Hour).Unix(), + ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + _, err = svc.ParseToken(token) + if err == nil { + t.Fatal("ParseToken should reject expired tokens") + } +} + +func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + Amount: "12.50", + OrderType: payment.OrderTypeSubscription, + PlanID: 7, + RedirectTo: "/purchase?from=wechat", + Scope: "snsapi_base", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + 
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 { + t.Fatalf("claims payment context mismatch: %+v", claims) + } + if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" { + t.Fatalf("claims redirect/scope mismatch: %+v", claims) + } +} + +func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService(nil) + _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"}) + if err == nil { + t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key") + } +} + +func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) { + t.Parallel() + + token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{ + TokenType: wechatPaymentResumeTokenType, + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + }) + svc := NewPaymentResumeService(nil) + _, err := svc.ParseWeChatPaymentResumeToken(token) + if err == nil { + t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing") + } +} + +func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-123", + PaymentType: payment.TypeWxpay, + IssuedAt: time.Now().Add(-30 * time.Minute).Unix(), + ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(), + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + _, err = svc.ParseWeChatPaymentResumeToken(token) + if err == nil { + t.Fatal("ParseWeChatPaymentResumeToken should reject 
expired tokens") + } +} + +func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-explicit-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-explicit-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key") + } +} + +func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + legacyKey := []byte("0123456789abcdef0123456789abcdef") + token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-legacy-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: legacyKey, + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-legacy-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key") + } +} + +func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", 
"explicit-payment-resume-signing-key") + + legacyKey := []byte("0123456789abcdef0123456789abcdef") + svc := newLegacyAwarePaymentResumeService(legacyKey) + + explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-explicit-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if explicitClaims.OpenID != "openid-explicit-key" { + t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key") + } + + legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-legacy-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if legacyClaims.OpenID != "openid-legacy-key" { + t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key") + } +} + +func TestNormalizeVisibleMethodSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + input string + want string + }{ + {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay}, + {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay}, + {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat}, + {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat}, + {name: 
"unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want { + t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want) + } + }) + } +} + +func TestVisibleMethodProviderKeyForSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + source string + want string + ok bool + }{ + {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true}, + {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true}, + {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true}, + {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true}, + {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source) + if got != tt.want || ok != tt.ok { + t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(1). 
+ Save(ctx) + if err != nil { + t.Fatalf("create alipay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != payment.TypeAlipay { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method payment.PaymentType + officialName string + officialTypes string + easyPayName string + easyPayTypes string + sourceSetting string + wantProvider string + }{ + { + name: "alipay uses official source", + method: payment.TypeAlipay, + officialName: "Official Alipay", + officialTypes: "alipay", + easyPayName: "EasyPay Alipay", + easyPayTypes: "alipay", + sourceSetting: VisibleMethodSourceOfficialAlipay, + wantProvider: payment.TypeAlipay, + }, + { + name: "alipay uses easypay source", + method: payment.TypeAlipay, + officialName: "Official Alipay", + officialTypes: "alipay", + easyPayName: "EasyPay Alipay", + easyPayTypes: "alipay", + sourceSetting: VisibleMethodSourceEasyPayAlipay, + wantProvider: payment.TypeEasyPay, + }, + { + name: "wxpay uses official source", + method: payment.TypeWxpay, + officialName: "Official WeChat", + officialTypes: "wxpay", + easyPayName: "EasyPay WeChat", + easyPayTypes: "wxpay", + sourceSetting: VisibleMethodSourceOfficialWechat, + wantProvider: payment.TypeWxpay, + }, + { + name: "wxpay uses easypay source", + method: payment.TypeWxpay, + officialName: "Official WeChat", + officialTypes: "wxpay", + easyPayName: "EasyPay WeChat", + easyPayTypes: "wxpay", + sourceSetting: VisibleMethodSourceEasyPayWechat, + wantProvider: 
payment.TypeEasyPay, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + officialProviderKey := payment.TypeAlipay + if tt.method == payment.TypeWxpay { + officialProviderKey = payment.TypeWxpay + } + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(officialProviderKey). + SetName(tt.officialName). + SetConfig("{}"). + SetSupportedTypes(tt.officialTypes). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName(tt.easyPayName). + SetConfig("{}"). + SetSupportedTypes(tt.easyPayTypes). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(tt.method): tt.sourceSetting, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != tt.wantProvider { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider) + } + }) + } +} + +func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(1). 
+ Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(payment.TypeAlipay): "", + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != "" { + t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey) + } + if inner.lastPaymentType != payment.TypeAlipay { + t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method payment.PaymentType + sourceValue string + wantMessage string + }{ + { + name: "invalid wxpay source", + method: payment.TypeWxpay, + sourceValue: "stripe", + wantMessage: "wxpay source must be one of the supported payment providers", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + officialProviderKey := payment.TypeAlipay + officialSupportedTypes := "alipay" + officialName := "Official Alipay" + easyPaySupportedTypes := "alipay" + easyPayName := "EasyPay Alipay" + if tt.method == payment.TypeWxpay { + officialProviderKey = payment.TypeWxpay + officialSupportedTypes = 
"wxpay" + officialName = "Official WeChat" + easyPaySupportedTypes = "wxpay" + easyPayName = "EasyPay WeChat" + } + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(officialProviderKey). + SetName(officialName). + SetConfig("{}"). + SetSupportedTypes(officialSupportedTypes). + SetEnabled(true). + SetSortOrder(1). + Save(ctx) + if err != nil { + t.Fatalf("create official provider: %v", err) + } + + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName(easyPayName). + SetConfig("{}"). + SetSupportedTypes(easyPaySupportedTypes). + SetEnabled(true). + SetSortOrder(2). + Save(ctx) + if err != nil { + t.Fatalf("create easypay provider: %v", err) + } + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + visibleMethodSourceSettingKey(tt.method): tt.sourceValue, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9) + if err == nil { + t.Fatal("SelectInstance should reject invalid visible method source configuration") + } + if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" { + t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE") + } + if infraerrors.Message(err) != tt.wantMessage { + t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage) + } + }) + } +} + +func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + entClient: newPaymentConfigServiceTestClient(t), + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil { + 
t.Fatal("SelectInstance should reject when no enabled provider instance exists") + } +} + +type captureLoadBalancer struct { + lastProviderKey string + lastPaymentType string +} + +func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) { + c.lastProviderKey = providerKey + c.lastPaymentType = paymentType + return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil +} + +func mustCreateFallbackSignedToken(t *testing.T, claims any) string { + t.Helper() + + payload, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + mac := hmac.New(sha256.New, []byte("sub2api-payment-resume")) + _, _ = mac.Write([]byte(encodedPayload)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return encodedPayload + "." 
+ signature +} diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 6fc23f974f731731b91356c2bd756f4cbd396aca..97fd76a071d4a535bfb210dca67e1b8ba20e607a 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -1,15 +1,18 @@ package service import ( + "bytes" "context" + "encoding/hex" "fmt" "log/slog" "math/rand/v2" + "os" + "strings" "sync" "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -45,6 +48,8 @@ const ( orderIDPrefix = "sub2_" ) +const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY" + // --- Types --- // generateOutTradeNo creates a unique external order ID for payment providers. @@ -65,29 +70,39 @@ func generateRandomString(n int) string { } type CreateOrderRequest struct { - UserID int64 - Amount float64 - PaymentType string - ClientIP string - IsMobile bool - SrcHost string - SrcURL string - OrderType string - PlanID int64 + UserID int64 + Amount float64 + PaymentType string + OpenID string + ClientIP string + IsMobile bool + IsWeChatBrowser bool + SrcHost string + SrcURL string + ReturnURL string + PaymentSource string + OrderType string + PlanID int64 } type CreateOrderResponse struct { - OrderID int64 `json:"order_id"` - Amount float64 `json:"amount"` - PayAmount float64 `json:"pay_amount"` - FeeRate float64 `json:"fee_rate"` - Status string `json:"status"` - PaymentType string `json:"payment_type"` - PayURL string `json:"pay_url,omitempty"` - QRCode string `json:"qr_code,omitempty"` - ClientSecret string `json:"client_secret,omitempty"` - ExpiresAt time.Time `json:"expires_at"` - PaymentMode string `json:"payment_mode,omitempty"` + OrderID int64 `json:"order_id"` + Amount float64 `json:"amount"` + PayAmount float64 
`json:"pay_amount"` + FeeRate float64 `json:"fee_rate"` + Status string `json:"status"` + ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"` + PaymentType string `json:"payment_type"` + OutTradeNo string `json:"out_trade_no,omitempty"` + PayURL string `json:"pay_url,omitempty"` + QRCode string `json:"qr_code,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"` + JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"` + JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + PaymentMode string `json:"payment_mode,omitempty"` + ResumeToken string `json:"resume_token,omitempty"` } type OrderListParams struct { @@ -165,10 +180,13 @@ type PaymentService struct { configService *PaymentConfigService userRepo UserRepository groupRepo GroupRepository + resumeService *PaymentResumeService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc.resumeService = psNewPaymentResumeService(configService) + return svc } // --- Provider Registry --- @@ -219,25 +237,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) { } } -// GetWebhookProvider returns the provider instance that 
should verify a webhook. -// It extracts out_trade_no from the raw body, looks up the order to find the -// original provider instance, and creates a provider with that instance's credentials. -// Falls back to the registry provider when the order cannot be found. -func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { - if outTradeNo != "" { - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) - if err == nil { - p, pErr := s.getOrderProvider(ctx, order) - if pErr == nil { - return p, nil - } - slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr) - } - } - s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) -} - // --- Helpers --- func psIsRefundStatus(s string) bool { @@ -262,6 +261,60 @@ func psNilIfEmpty(s string) *string { return &s } +func (s *PaymentService) paymentResume() *PaymentResumeService { + if s.resumeService != nil { + return s.resumeService + } + return psNewPaymentResumeService(s.configService) +} + +func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService { + return newLegacyAwarePaymentResumeService(legacyKey) +} + +func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService { + return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService)) +} + +func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService { + signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey) + return NewPaymentResumeService(signingKey, verifyFallbacks...) 
+} + +func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + +func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) { + signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv)) + if len(signingKey) == 0 { + if len(legacyKey) == 0 { + return nil, nil + } + return legacyKey, nil + } + if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) { + return signingKey, nil + } + return signingKey, [][]byte{legacyKey} +} + +func parsePaymentResumeSigningKey(raw string) []byte { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if len(raw) >= 64 && len(raw)%2 == 0 { + if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 { + return decoded + } + } + return []byte(raw) +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go new file mode 100644 index 0000000000000000000000000000000000000000..899bd7a0203e06a255e9bf3b2b2b49d61c281029 --- /dev/null +++ b/backend/internal/service/payment_visible_method_instances.go @@ -0,0 +1,242 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string { + methodSet := make(map[string]struct{}, 2) + addMethod := func(method string) { + method = NormalizeVisibleMethod(method) + switch method { + case payment.TypeAlipay, payment.TypeWxpay: + methodSet[method] = struct{}{} + } + } + + switch strings.TrimSpace(providerKey) { + case payment.TypeAlipay: + if 
strings.TrimSpace(supportedTypes) == "" { + addMethod(payment.TypeAlipay) + break + } + for _, supportedType := range splitTypes(supportedTypes) { + if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay { + addMethod(payment.TypeAlipay) + break + } + } + case payment.TypeWxpay: + if strings.TrimSpace(supportedTypes) == "" { + addMethod(payment.TypeWxpay) + break + } + for _, supportedType := range splitTypes(supportedTypes) { + if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay { + addMethod(payment.TypeWxpay) + break + } + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(supportedTypes) { + addMethod(supportedType) + } + } + + methods := make([]string, 0, len(methodSet)) + for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if _, ok := methodSet[method]; ok { + methods = append(methods, method) + } + } + return methods +} + +func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool { + if inst == nil || !inst.Enabled { + return false + } + method = NormalizeVisibleMethod(method) + for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) { + if candidate == method { + return true + } + } + return false +} + +func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance { + filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances)) + for _, inst := range instances { + if providerSupportsVisibleMethod(inst, method) { + filtered = append(filtered, inst) + } + } + return filtered +} + +func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance { + filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances)) + for _, inst := range instances { + if !providerSupportsVisibleMethod(inst, method) { + continue + } + if 
!strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) { + continue + } + filtered = append(filtered, inst) + } + return filtered +} + +func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string { + seen := make(map[string]struct{}, len(instances)) + keys := make([]string, 0, len(instances)) + for _, inst := range instances { + if inst == nil { + continue + } + key := strings.TrimSpace(inst.ProviderKey) + if key == "" { + continue + } + normalized := strings.ToLower(key) + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + keys = append(keys, key) + } + return keys +} + +func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" { + return nil + } + for _, inst := range instances { + if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) { + return inst + } + } + return nil +} + +func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( + ctx context.Context, + excludeID int64, + providerKey string, + supportedTypes string, + enabled bool, +) error { + // Visible methods are selected by configured source (official/easypay), + // so multiple enabled providers can intentionally claim the same user-facing + // method. Order creation and limits will route through the configured source. 
+ _, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled + return nil +} + +func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) { + method = NormalizeVisibleMethod(method) + sourceKey := visibleMethodSourceSettingKey(method) + rawSource := "" + if s != nil && s.settingRepo != nil && sourceKey != "" { + value, err := s.settingRepo.GetValue(ctx, sourceKey) + if err != nil { + if !errors.Is(err, ErrSettingNotFound) { + return "", fmt.Errorf("get %s: %w", sourceKey, err) + } + } else { + rawSource = value + } + } + + normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true) + if err != nil { + return "", err + } + if normalizedSource == "" { + return "", nil + } + providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource) + if !ok { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return providerKey, nil +} + +func (s *PaymentConfigService) resolveVisibleMethodProviderKey( + ctx context.Context, + method string, + matching []*dbent.PaymentProviderInstance, +) (string, error) { + switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) { + case 0: + return "", nil + case 1: + return strings.TrimSpace(providerKeys[0]), nil + default: + providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) + if err != nil { + return "", err + } + if providerKey == "" { + return "", nil + } + selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) + if selected == nil { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source has no enabled provider instance", method), + ) + } + return strings.TrimSpace(selected.ProviderKey), nil + } +} + +func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( + ctx 
context.Context, + method string, +) (*dbent.PaymentProviderInstance, error) { + if s == nil || s.entClient == nil { + return nil, nil + } + + method = NormalizeVisibleMethod(method) + if method != payment.TypeAlipay && method != payment.TypeWxpay { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + Order(paymentproviderinstance.BySortOrder()). + All(ctx) + if err != nil { + return nil, fmt.Errorf("query enabled payment providers: %w", err) + } + + matching := filterEnabledVisibleMethodInstances(instances, method) + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil { + return nil, err + } + if providerKey == "" { + if len(matching) == 0 { + return nil, nil + } + return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil + } + return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil +} diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go new file mode 100644 index 0000000000000000000000000000000000000000..f2da40d9b4c6dfda2314347294b25c3f1f27bbea --- /dev/null +++ b/backend/internal/service/payment_webhook_provider.go @@ -0,0 +1,148 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +// GetWebhookProvider returns the provider instance that should verify a webhook. +// It resolves the original provider instance from the order whenever possible and +// only falls back to a registry provider for legacy/single-instance scenarios. 
+func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo) + if err != nil { + return nil, err + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers[0], nil +} + +// GetWebhookProviders returns provider candidates that can verify the webhook. +// Official WeChat Pay may require multiple candidates because the callback body +// cannot be bound to a merchant before decryption. +func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) { + if outTradeNo != "" { + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) + if err == nil { + if psHasPinnedProviderInstance(order) { + prov, err := s.getPinnedOrderProvider(ctx, order) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + inst, err := s.getOrderProviderInstance(ctx, order) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + prov, err := s.createProviderFromInstance(ctx, inst) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + s.EnsureProviders(ctx) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil + } + } + + if strings.TrimSpace(providerKey) == payment.TypeWxpay { + return s.getEnabledWebhookProvidersByKey(ctx, providerKey) + } + + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is 
ambiguous for %s", providerKey) + } + + s.EnsureProviders(ctx) + prov, err := s.registry.GetProviderByKey(providerKey) + if err != nil { + return nil, err + } + return []payment.Provider{prov}, nil +} + +func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst == nil { + return nil, fmt.Errorf("order %d provider instance is missing", o.ID) + } + return s.createProviderFromInstance(ctx, inst) +} + +func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" || s == nil || s.entClient == nil { + return false + } + + count, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Count(ctx) + if err != nil { + slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err) + return false + } + return count <= 1 +} + +func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { + return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "")) +} + +func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) { + providerKey = strings.TrimSpace(providerKey) + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)). 
+ All(ctx) + if err != nil { + return nil, fmt.Errorf("query webhook provider instances: %w", err) + } + if len(instances) == 0 { + return nil, payment.ErrProviderNotFound + } + + providers := make([]payment.Provider, 0, len(instances)) + for _, inst := range instances { + prov, provErr := s.createProviderFromInstance(ctx, inst) + if provErr != nil { + slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr) + continue + } + providers = append(providers, prov) + } + if len(providers) == 0 { + return nil, payment.ErrProviderNotFound + } + return providers, nil +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f3efa1f49784ef2989e76b168802156b6365404 --- /dev/null +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -0,0 +1,510 @@ +//go:build unit + +package service + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "strconv" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef" + +type webhookProviderTestDouble struct { + key string + types []payment.PaymentType +} + +func (p webhookProviderTestDouble) Name() string { return p.key } +func (p webhookProviderTestDouble) ProviderKey() string { return p.key } +func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types } +func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p 
webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + require.NoError(t, err) + + encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey)) + require.NoError(t, err) + return encrypted +} + +func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer { + return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) +} + +func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-" + suffix, + "mchId": "mch-" + suffix, + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-" + suffix, + "certSerial": "cert-serial-" + suffix, + }) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})). + SetSupportedTypes("stripe"). 
+ SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpayDirect, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("easypay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-disabled-legacy"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(false). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-enabled-current"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-only"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeWxpay + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipayDirect, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-snapshot"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 42, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(inst.ID, 10), + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-legacy-fallback"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})). + SetSupportedTypes("stripe"). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + ID: 43, + PaymentType: payment.TypeStripe, + ProviderSnapshot: map[string]any{ + "schema_version": 1, + "provider_instance_id": "999999", + "provider_key": payment.TypeStripe, + }, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.Nil(t, got) + require.Error(t, err) + require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing") +} + +func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b") + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig(wxpayConfigA). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-b"). + SetConfig(wxpayConfigB). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "") + require.NoError(t, err) + require.Len(t, providers, 2) +} + +func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-a"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). 
+ Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("alipay-b"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + registry: payment.NewRegistry(), + providersLoaded: true, + } + + _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "") + require.Error(t, err) + require.Contains(t, err.Error(), "ambiguous") +} + +func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeStripe, + types: []payment.PaymentType{payment.TypeStripe}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "") + require.NoError(t, err) + require.Len(t, providers, 1) + prov := providers[0] + require.Equal(t, payment.TypeStripe, prov.ProviderKey()) +} + +func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("webhook"). + Save(ctx) + require.NoError(t, err) + + pinnedInstanceID := "999" + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("TEST-RECHARGE"). + SetOutTradeNo("sub2_test_pinned_order"). 
+ SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(pinnedInstanceID). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeWxpay, + types: []payment.PaymentType{payment.TypeWxpay}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + require.Error(t, err) + require.Contains(t, err.Error(), "provider instance") +} + +func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("snapshot-webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("snapshot-webhook"). + Save(ctx) + require.NoError(t, err) + + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b") + instA, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-a"). + SetConfig(wxpayConfigA). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-snapshot-b"). + SetConfig(wxpayConfigB). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(66). + SetPayAmount(66). + SetFeeRate(0). + SetRechargeCode("SNAPSHOT-WEBHOOK"). + SetOutTradeNo("sub2_test_snapshot_webhook_order"). 
+ SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderSnapshot(map[string]any{ + "schema_version": 1, + "provider_instance_id": strconv.FormatInt(instA.ID, 10), + "provider_key": payment.TypeWxpay, + "payment_mode": "native", + }). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + registry: payment.NewRegistry(), + providersLoaded: true, + } + + providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order") + require.NoError(t, err) + require.Len(t, providers, 1) + require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey()) +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 2bf48702aa2b5bad79b39cff966eb0221509847e..91a02901da5cb39069ec3e3e432b34cb47de6276 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + // GPT-5.5 回退到 GPT-5.4 定价 + if strings.HasPrefix(model, "gpt-5.5") { + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) + return openAIGPT54FallbackPricing + } + if strings.HasPrefix(model, "gpt-5.4-mini") { logger.With(zap.String("component", "service.pricing")). 
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)")) @@ -812,6 +819,16 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { return openAIGPT54FallbackPricing } + if isOpenAIImageGenerationModel(model) { + for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} { + if pricing, ok := s.pricingData[candidate]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate) + return pricing + } + } + return nil + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 13a5c70c1af1092b9311c6f84b57d60b6e317e4d..e2bd7cf33a5d697f68ae7a46850339d7bd2a61f9 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -128,6 +128,21 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t require.Zero(t, got.LongContextInputTokenThreshold) } +func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) { + imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3} + textPricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-image-2": imagePricing, + "gpt-5.4": textPricing, + }, + } + + got := svc.GetModelPricing("gpt-image-3") + require.Same(t, imagePricing, got) +} + func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { raw := map[string]any{ "gpt-5.4": map[string]any{ diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 53581574bd907a7b9baa28a1333040f5971e7662..4730303f30ed8738e7f74a87bb79d535a95421df 100644 --- a/backend/internal/service/ratelimit_service.go +++ 
b/backend/internal/service/ratelimit_service.go @@ -1,8 +1,10 @@ package service import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" "net/http" "strconv" @@ -23,6 +25,7 @@ type RateLimitService struct { geminiQuotaService *GeminiQuotaService tempUnschedCache TempUnschedCache timeoutCounterCache TimeoutCounterCache + openAI403CounterCache OpenAI403CounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator usageCacheMu sync.RWMutex @@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface { const geminiPrecheckCacheTTL = time.Minute +const ( + openAI403CooldownMinutesDefault = 10 + openAI403DisableThreshold = 3 + openAI403CounterWindowMinutes = 180 +) + // NewRateLimitService 创建RateLimitService实例 func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { return &RateLimitService{ @@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) { s.timeoutCounterCache = cache } +// SetOpenAI403CounterCache 设置 OpenAI 403 连续失败计数器(可选依赖) +func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) { + s.openAI403CounterCache = cache +} + // SetSettingService 设置系统设置服务(可选依赖) func (s *RateLimitService) SetSettingService(settingService *SettingService) { s.settingService = settingService @@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string { + prefix = strings.TrimSpace(prefix) + if prefix != "" && !strings.HasSuffix(prefix, " ") { + prefix += " " + } + + if msg := strings.TrimSpace(upstreamMsg); msg != "" { + return prefix + msg + } + + rawBody := bytes.TrimSpace(responseBody) + if 
len(rawBody) > 0 { + if json.Valid(rawBody) { + var compact bytes.Buffer + if err := json.Compact(&compact, rawBody); err == nil { + return prefix + truncateForLog(compact.Bytes(), 512) + } + } + return prefix + truncateForLog(rawBody, 512) + } + + return prefix + fallback +} + // handle403 处理 403 Forbidden 错误 // Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; // 其他平台保持原有 SetError 行为。 @@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst if account.Platform == PlatformAntigravity { return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) } - // 非 Antigravity 平台:保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg + if account.Platform == PlatformOpenAI { + return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody) } + // 非 Antigravity 平台:保持原有行为 + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } +func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) + + if s.openAI403CounterCache == nil { + s.handleAuthError(ctx, account, msg) + return true + } + + count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes) + if err != nil { + slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + if count >= openAI403DisableThreshold { + msg = fmt.Sprintf("%s | consecutive_403=%d/%d", msg, count, openAI403DisableThreshold) + s.handleAuthError(ctx, account, msg) + 
return true + } + + until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute) + reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + slog.Warn( + "openai_403_temp_unschedulable", + "account_id", account.ID, + "until", until, + "count", count, + "threshold", openAI403DisableThreshold, + ) + return true +} + // handleAntigravity403 处理 Antigravity 平台的 403 错误 // validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) // violation(违规封号)→ 永久 SetError(需人工处理) @@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac switch fbType { case forbiddenTypeValidation: // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 - msg := "Validation required (403): account needs Google verification" - if upstreamMsg != "" { - msg = "Validation required (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Validation required (403):", + upstreamMsg, + responseBody, + "account needs Google verification", + ) if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { msg += " | validation_url: " + validationURL } @@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac case forbiddenTypeViolation: // 违规封号: 永久禁用,需人工处理 - msg := "Account violation (403): terms of service violation" - if upstreamMsg != "" { - msg = "Account violation (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Account violation (403):", + upstreamMsg, + responseBody, + "terms of service violation", + ) s.handleAuthError(ctx, account, msg) return true default: // 通用 403: 保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if 
upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } @@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } + s.ResetOpenAI403Counter(ctx, accountID) return nil } +func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) { + if s == nil || s.openAI403CounterCache == nil || accountID <= 0 { + return + } + if err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID); err != nil { + slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err) + } +} + // RecoverAccountState 按需恢复账号的可恢复运行时状态。 func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { account, err := s.accountRepo.GetByID(ctx, accountID) @@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in } result.ClearedRateLimit = true } + if result.ClearedError || result.ClearedRateLimit { + s.ResetOpenAI403Counter(ctx, accountID) + } return result, nil } diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 9e5e2b0e8adbb09ce4432564e0c53b18249f3066..73b7849fbac594fa3a4d13bce723128a2ad86aa4 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct { updateCredentialsCalls int lastCredentials map[string]any lastErrorMsg string + lastTempReason string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -30,6 +31,7 @@ 
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ + r.lastTempReason = reason return nil } @@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct { err error } +type openAI403CounterCacheStub struct { + counts []int64 + resetCalls []int64 + err error +} + +func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) { + if s.err != nil { + return 0, s.err + } + if len(s.counts) == 0 { + return 1, nil + } + count := s.counts[0] + s.counts = s.counts[1:] + return count, nil +} + +func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { r.accounts = append(r.accounts, account) return r.err diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2fd11b7169087433e015fd3f6f19239bfe11019b --- /dev/null +++ b/backend/internal/service/ratelimit_service_403_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{1}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + 
context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"temporary edge rejection"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Contains(t, repo.lastTempReason, "temporary edge rejection") + require.Contains(t, repo.lastTempReason, "(1/3)") +} + +func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{3}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 302, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy") + require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3") +} diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 89c754c8550b587efbbf8a345a3fb6b0ff52b519..619bb7730d53dd3a77e7ee72d4c20de4acfaf269 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -7,6 +7,9 @@ import ( "net/http" "testing" "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" ) func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) { @@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) { } } +func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) { + repo := 
&rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 201, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy") + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + +func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, `"access_denied"`) + require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`) + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) { // Test when only secondary has data, no window_minutes sUsed := 60.0 diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 27a855eab9765705f74f900a303f9ccc80d90ae0..38686319d53f9f905d98a12dfead420cc19d6d9b 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -114,6 +114,253 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder } +type ProviderDefaultGrantSettings 
struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + GrantOnSignup bool + GrantOnFirstBind bool +} + +type AuthSourceDefaultSettings struct { + Email ProviderDefaultGrantSettings + LinuxDo ProviderDefaultGrantSettings + OIDC ProviderDefaultGrantSettings + WeChat ProviderDefaultGrantSettings + ForceEmailOnThirdPartySignup bool +} + +type authSourceDefaultKeySet struct { + balance string + concurrency string + subscriptions string + grantOnSignup string + grantOnFirstBind string +} + +var ( + emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultEmailBalance, + concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, + subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + } + linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultLinuxDoBalance, + concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, + subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + } + oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultOIDCBalance, + concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, + subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + } + weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultWeChatBalance, + concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, + subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + grantOnFirstBind: 
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + } +) + +const ( + defaultAuthSourceBalance = 0 + defaultAuthSourceConcurrency = 5 + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontend = "/auth/wechat/callback" +) + +func normalizeWeChatConnectModeSetting(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + case "mobile": + return "mobile" + default: + return "open" + } +} + +func defaultWeChatConnectScopeForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return "snsapi_userinfo" + case "mobile": + return "" + } + return defaultWeChatConnectScopes +} + +func normalizeWeChatConnectScopeSetting(raw, mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopeForMode(mode) + } + case "mobile": + return "" + default: + return defaultWeChatConnectScopes + } +} + +func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(mode) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + return openEnabled, mpEnabled, mobileEnabled + } + + if !enabled { + return false, 
false, false + } + if mode == "mp" { + return false, true, false + } + if mode == "mobile" { + return false, false, true + } + return true, false, false +} + +func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + mode = normalizeWeChatConnectModeSetting(mode) + switch mode { + case "open": + if openEnabled { + return "open" + } + case "mp": + if mpEnabled { + return "mp" + } + case "mobile": + if mobileEnabled { + return "mobile" + } + } + switch { + case openEnabled: + return "open" + case mpEnabled: + return "mp" + case mobileEnabled: + return "mobile" + default: + return mode + } +} + +func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode)) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled] + if !enabledConfigured && + enabled && + !openEnabled && + !mpEnabled && + !mobileEnabled && + (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) { + return base.OpenEnabled, base.MPEnabled, base.MobileEnabled + } + return openEnabled, mpEnabled, mobileEnabled + } + if !enabled { + return false, false, false + } + if base.OpenEnabled || base.MPEnabled || base.MobileEnabled { + return base.OpenEnabled, base.MPEnabled, 
base.MobileEnabled + } + return parseWeChatConnectCapabilitySettings(settings, enabled, mode) +} + +func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig { + base := config.WeChatConnectConfig{} + if s != nil && s.cfg != nil { + base = s.cfg.WeChat + } + + enabled := base.Enabled + if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok { + enabled = strings.TrimSpace(raw) == "true" + } + + legacyAppID := strings.TrimSpace(firstNonEmpty( + settings[SettingKeyWeChatConnectAppID], + base.AppID, + base.OpenAppID, + base.MPAppID, + base.MobileAppID, + )) + legacyAppSecret := strings.TrimSpace(firstNonEmpty( + settings[SettingKeyWeChatConnectAppSecret], + base.AppSecret, + base.OpenAppSecret, + base.MPAppSecret, + base.MobileAppSecret, + )) + openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], base.OpenAppID, legacyAppID)) + openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], base.OpenAppSecret, legacyAppSecret)) + mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], base.MPAppID, legacyAppID)) + mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], base.MPAppSecret, legacyAppSecret)) + mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], base.MobileAppID, legacyAppID)) + mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], base.MobileAppSecret, legacyAppSecret)) + + modeRaw := firstNonEmpty(settings[SettingKeyWeChatConnectMode], base.Mode) + openEnabled, mpEnabled, mobileEnabled := mergeWeChatConnectCapabilitySettings(settings, base, enabled, modeRaw) + mode := normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, modeRaw) + + return WeChatConnectOAuthConfig{ + Enabled: enabled, + LegacyAppID: legacyAppID, + LegacyAppSecret: legacyAppSecret, + OpenAppID: 
openAppID, + OpenAppSecret: openAppSecret, + MPAppID: mpAppID, + MPAppSecret: mpAppSecret, + MobileAppID: mobileAppID, + MobileAppSecret: mobileAppSecret, + OpenEnabled: openEnabled, + MPEnabled: mpEnabled, + MobileEnabled: mobileEnabled, + Mode: mode, + Scopes: normalizeWeChatConnectScopeSetting(firstNonEmpty(settings[SettingKeyWeChatConnectScopes], base.Scopes), mode), + RedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectRedirectURL], base.RedirectURL)), + FrontendRedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectFrontendRedirectURL], base.FrontendRedirectURL, defaultWeChatConnectFrontend)), + } +} + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -156,6 +403,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyForceEmailOnThirdPartySignup, SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, @@ -178,6 +426,22 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectOpenAppID, + SettingKeyWeChatConnectOpenAppSecret, + SettingKeyWeChatConnectMPAppID, + SettingKeyWeChatConnectMPAppSecret, + SettingKeyWeChatConnectMobileAppID, + SettingKeyWeChatConnectMobileAppSecret, + SettingKeyWeChatConnectOpenEnabled, + SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, SettingKeyBackendModeEnabled, SettingPaymentEnabled, 
SettingKeyOIDCConnectEnabled, @@ -186,6 +450,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL, SettingKeyAccountQuotaNotifyEnabled, + SettingKeyChannelMonitorEnabled, + SettingKeyChannelMonitorDefaultIntervalSeconds, + SettingKeyAvailableChannelsEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -212,6 +479,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -232,6 +500,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, @@ -254,6 +523,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + WeChatOAuthEnabled: weChatEnabled, + WeChatOAuthOpenEnabled: weChatOpenEnabled, + WeChatOAuthMPEnabled: weChatMPEnabled, + WeChatOAuthMobileEnabled: weChatMobileEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, @@ -262,9 +535,88 @@ func (s *SettingService) GetPublicSettings(ctx 
context.Context) (*PublicSettings AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true", BalanceLowNotifyThreshold: balanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL], + + ChannelMonitorEnabled: !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]), + ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]), + + AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true", }, nil } +// channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval +// (mirrors the monitor-level constraint but lives here so setting_service stays decoupled). +const ( + channelMonitorIntervalMin = 15 + channelMonitorIntervalMax = 3600 + channelMonitorIntervalFallback = 60 +) + +// parseChannelMonitorInterval parses the stored string and clamps to [15, 3600]. +// Empty / invalid input falls back to channelMonitorIntervalFallback. +func parseChannelMonitorInterval(raw string) int { + v, err := strconv.Atoi(strings.TrimSpace(raw)) + if err != nil { + return channelMonitorIntervalFallback + } + return clampChannelMonitorInterval(v) +} + +// clampChannelMonitorInterval clamps v to the allowed range. 0 means "not provided". +func clampChannelMonitorInterval(v int) int { + if v <= 0 { + return 0 + } + if v < channelMonitorIntervalMin { + return channelMonitorIntervalMin + } + if v > channelMonitorIntervalMax { + return channelMonitorIntervalMax + } + return v +} + +// ChannelMonitorRuntime is the lightweight view of the channel monitor feature +// consumed by the runner and user-facing handlers. +type ChannelMonitorRuntime struct { + Enabled bool + DefaultIntervalSeconds int +} + +// GetChannelMonitorRuntime reads the channel monitor feature flags directly from +// the settings store. Fail-open: on error returns Enabled=true with the default interval. 
+func (s *SettingService) GetChannelMonitorRuntime(ctx context.Context) ChannelMonitorRuntime { + vals, err := s.settingRepo.GetMultiple(ctx, []string{ + SettingKeyChannelMonitorEnabled, + SettingKeyChannelMonitorDefaultIntervalSeconds, + }) + if err != nil { + return ChannelMonitorRuntime{Enabled: true, DefaultIntervalSeconds: channelMonitorIntervalFallback} + } + return ChannelMonitorRuntime{ + Enabled: !isFalseSettingValue(vals[SettingKeyChannelMonitorEnabled]), + DefaultIntervalSeconds: parseChannelMonitorInterval(vals[SettingKeyChannelMonitorDefaultIntervalSeconds]), + } +} + +// AvailableChannelsRuntime is the lightweight view of the available-channels feature +// switch consumed by the user-facing handler. +type AvailableChannelsRuntime struct { + Enabled bool +} + +// GetAvailableChannelsRuntime reads the available-channels feature switch directly +// from the settings store. Fail-closed: on error returns Enabled=false, matching +// the opt-in default (unknown ↔ disabled). +func (s *SettingService) GetAvailableChannelsRuntime(ctx context.Context) AvailableChannelsRuntime { + vals, err := s.settingRepo.GetMultiple(ctx, []string{SettingKeyAvailableChannelsEnabled}) + if err != nil { + return AvailableChannelsRuntime{Enabled: false} + } + return AvailableChannelsRuntime{ + Enabled: vals[SettingKeyAvailableChannelsEnabled] == "true", + } +} + // SetOnUpdateCallback sets a callback function to be called when settings are updated // This is used for cache invalidation (e.g., HTML cache in frontend server) func (s *SettingService) SetOnUpdateCallback(callback func()) { @@ -276,50 +628,75 @@ func (s *SettingService) SetVersion(version string) { s.version = version } -// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection -// This implements the web.PublicSettingsProvider interface +// PublicSettingsInjectionPayload is the JSON shape embedded into HTML as +// `window.__APP_CONFIG__` so the frontend can hydrate feature flags 
& site +// config before the first XHR finishes. +// +// INVARIANT: every `json` tag here MUST also exist on handler/dto.PublicSettings. +// If you forget a feature-flag field here, the frontend's +// `cachedPublicSettings.xxx_enabled` will be `undefined` on refresh until the +// async `/api/v1/settings/public` call returns — which causes opt-in menus +// (strict `=== true`) to flicker off/on. See +// frontend/src/utils/featureFlags.ts for the matching registry. +// +// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch +// drift automatically (see setting_service_injection_test.go). +type PublicSettingsInjectionPayload struct { + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + CustomEndpoints json.RawMessage `json:"custom_endpoints"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool 
`json:"wechat_oauth_enabled"` + WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` + WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` + WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"` + OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` + OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` + BackendModeEnabled bool `json:"backend_mode_enabled"` + PaymentEnabled bool `json:"payment_enabled"` + Version string `json:"version"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + + // Feature flags — MUST match the opt-in/opt-out registry in + // frontend/src/utils/featureFlags.ts. Missing a field here is the bug + // that hid the "可用渠道" menu on page refresh. + ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` + AvailableChannelsEnabled bool `json:"available_channels_enabled"` +} + +// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection. +// This implements the web.PublicSettingsProvider interface. 
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) { settings, err := s.GetPublicSettings(ctx) if err != nil { return nil, err } - // Return a struct that matches the frontend's expected format - return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - TableDefaultPageSize int `json:"table_default_page_size"` - TablePageSizeOptions []int `json:"table_page_size_options"` - CustomMenuItems json.RawMessage `json:"custom_menu_items"` - CustomEndpoints json.RawMessage `json:"custom_endpoints"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - BackendModeEnabled bool `json:"backend_mode_enabled"` - PaymentEnabled bool `json:"payment_enabled"` - OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` - OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` - Version string `json:"version,omitempty"` - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - AccountQuotaNotifyEnabled bool 
`json:"account_quota_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` - }{ + return &PublicSettingsInjectionPayload{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, @@ -344,18 +721,84 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - BackendModeEnabled: settings.BackendModeEnabled, - PaymentEnabled: settings.PaymentEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, + WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, + WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled, + WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, + BackendModeEnabled: settings.BackendModeEnabled, + PaymentEnabled: settings.PaymentEnabled, Version: s.version, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + + ChannelMonitorEnabled: settings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, + AvailableChannelsEnabled: settings.AvailableChannelsEnabled, }, nil } +func DefaultWeChatConnectScopesForMode(mode string) string { + return defaultWeChatConnectScopeForMode(mode) +} + +func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { + cfg := s.effectiveWeChatConnectOAuthConfig(settings) + + 
if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) { + return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if cfg.OpenEnabled { + if cfg.AppIDForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured") + } + if cfg.AppSecretForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured") + } + } + if cfg.MPEnabled { + if cfg.AppIDForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured") + } + if cfg.AppSecretForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured") + } + } + if cfg.MobileEnabled { + if cfg.AppIDForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured") + } + if cfg.AppSecretForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured") + } + } + if v := strings.TrimSpace(cfg.RedirectURL); v != "" { + if err := config.ValidateAbsoluteHTTPURL(v); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") + } + return cfg, nil +} + +func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) { + cfg := 
s.effectiveWeChatConnectOAuthConfig(settings) + if !cfg.Enabled { + return false, false, false, false + } + + openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != "" + mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != "" + mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != "" + + return openReady || mpReady, openReady, mpReady, mobileReady +} + // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -478,19 +921,130 @@ func parseCustomMenuItemURLs(raw string) []string { return urls } +func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.UsePKCEExplicit { + return base.UsePKCE + } + return true +} + +func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.ValidateIDTokenExplicit { + return base.ValidateIDToken + } + return true +} + +func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool { + if configured { + return strings.TrimSpace(raw) == "true" + } + if explicit { + return explicitValue + } + return false +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { - if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { + return err + } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) { + rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{ + 
SettingKeyOIDCConnectUsePKCE, + SettingKeyOIDCConnectValidateIDToken, + }) + if err != nil { + return false, false, fmt.Errorf("get oidc security write defaults: %w", err) + } + + base := config.OIDCConnectConfig{} + if s != nil && s.cfg != nil { + base = s.cfg.OIDC + } + + rawUsePKCE, hasUsePKCE := rawSettings[SettingKeyOIDCConnectUsePKCE] + rawValidateIDToken, hasValidateIDToken := rawSettings[SettingKeyOIDCConnectValidateIDToken] + + return oidcCompatibilityWriteDefault(base, hasUsePKCE, rawUsePKCE, base.UsePKCEExplicit, base.UsePKCE), + oidcCompatibilityWriteDefault(base, hasValidateIDToken, rawValidateIDToken, base.ValidateIDTokenExplicit, base.ValidateIDToken), + nil +} + +// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write. +func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error { + updates, err := s.buildSystemSettingsUpdates(ctx, settings) + if err != nil { return err } + + authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults) + if err != nil { + return err + } + for key, value := range authSourceUpdates { + updates[key] = value + } + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + s.refreshCachedSettings(settings) + } + return err +} + +func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return nil, err + } normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) } if normalizedWhitelist == nil { normalizedWhitelist = []string{} } 
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist + alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) + if err != nil { + return nil, err + } + wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) + if err != nil { + return nil, err + } + settings.PaymentVisibleMethodAlipaySource = alipaySource + settings.PaymentVisibleMethodWxpaySource = wxpaySource + settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) + settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) + settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMode = normalizeWeChatConnectStoredMode( + settings.WeChatConnectOpenEnabled, + settings.WeChatConnectMPEnabled, + settings.WeChatConnectMobileEnabled, + settings.WeChatConnectMode, + ) + settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) + settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) + 
settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) + if settings.WeChatConnectFrontendRedirectURL == "" { + settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } updates := make(map[string]string) @@ -499,7 +1053,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) if err != nil { - return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err) } updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) @@ -560,6 +1114,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret } + // WeChat Connect OAuth 登录 + updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) + updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID + updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID + updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID + updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID + updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled) + updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled) + updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled) + updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode + updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes + 
updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL + updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL + if settings.WeChatConnectAppSecret != "" { + updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret + } + if settings.WeChatConnectOpenAppSecret != "" { + updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret + } + if settings.WeChatConnectMPAppSecret != "" { + updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret + } + if settings.WeChatConnectMobileAppSecret != "" { + updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -578,7 +1158,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) if err != nil { - return fmt.Errorf("marshal table page size options: %w", err) + return nil, fmt.Errorf("marshal table page size options: %w", err) } updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems @@ -587,9 +1167,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { - return fmt.Errorf("marshal default subscriptions: %w", err) + return nil, fmt.Errorf("marshal default subscriptions: %w", err) } updates[SettingKeyDefaultSubscriptions] = 
string(defaultSubsJSON) @@ -612,6 +1193,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) } + // Channel monitor feature switch + updates[SettingKeyChannelMonitorEnabled] = strconv.FormatBool(settings.ChannelMonitorEnabled) + if v := clampChannelMonitorInterval(settings.ChannelMonitorDefaultIntervalSeconds); v > 0 { + updates[SettingKeyChannelMonitorDefaultIntervalSeconds] = strconv.Itoa(v) + } + + // Available channels feature switch + updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled) + // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion @@ -626,6 +1216,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource + updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource + updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) + updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) + updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) @@ -634,32 +1229,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings 
*SystemSet updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - versionBoundsSF.Forget("version_bounds") - versionBoundsCache.Store(&cachedVersionBounds{ - min: settings.MinClaudeCodeVersion, - max: settings.MaxClaudeCodeVersion, - expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), - }) - backendModeSF.Forget("backend_mode") - backendModeCache.Store(&cachedBackendMode{ - value: settings.BackendModeEnabled, - expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), - }) - gatewayForwardingSF.Forget("gateway_forwarding") - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), - }) - if s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update + return updates, nil +} + +func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { + if settings == nil { + return nil, nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return nil, err } } - return err + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, 
oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + return updates, nil +} + +func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { + if settings == nil { + return + } + + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 + versionBoundsSF.Forget("version_bounds") + versionBoundsCache.Store(&cachedVersionBounds{ + min: settings.MinClaudeCodeVersion, + max: settings.MaxClaudeCodeVersion, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), + }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + gatewayForwardingSF.Forget("gateway_forwarding") + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } } func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { @@ -910,6 +1539,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。 +func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int { + value, 
err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit) + if err != nil || value == "" { + return 0 + } + if v, err := strconv.Atoi(value); err == nil && v >= 0 { + return v + } + return 0 +} + // GetDefaultSubscriptions 获取新用户默认订阅配置列表。 func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) @@ -919,6 +1560,88 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS return parseDefaultSubscriptions(value) } +func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) { + keys := []string{ + SettingKeyAuthSourceDefaultEmailBalance, + SettingKeyAuthSourceDefaultEmailConcurrency, + SettingKeyAuthSourceDefaultEmailSubscriptions, + SettingKeyAuthSourceDefaultEmailGrantOnSignup, + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + SettingKeyAuthSourceDefaultLinuxDoBalance, + SettingKeyAuthSourceDefaultLinuxDoConcurrency, + SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + SettingKeyAuthSourceDefaultOIDCBalance, + SettingKeyAuthSourceDefaultOIDCConcurrency, + SettingKeyAuthSourceDefaultOIDCSubscriptions, + SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + SettingKeyAuthSourceDefaultWeChatBalance, + SettingKeyAuthSourceDefaultWeChatConcurrency, + SettingKeyAuthSourceDefaultWeChatSubscriptions, + SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyForceEmailOnThirdPartySignup, + } + + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get auth source default settings: %w", err) + } + + return &AuthSourceDefaultSettings{ + Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys), + LinuxDo: 
parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), + OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), + WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", + }, nil +} + +func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) { + result := ProviderDefaultGrantSettings{ + Balance: s.GetDefaultBalance(ctx), + Concurrency: s.GetDefaultConcurrency(ctx), + Subscriptions: s.GetDefaultSubscriptions(ctx), + } + + defaults, err := s.GetAuthSourceDefaultSettings(ctx) + if err != nil { + return result, false, err + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok { + return result, false, nil + } + + enabled := providerDefaults.GrantOnSignup + if firstBind { + enabled = providerDefaults.GrantOnFirstBind + } + if !enabled { + return result, false, nil + } + + return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil +} + +func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { + updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings) + if err != nil { + return err + } + if len(updates) == 0 { + return nil + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return fmt.Errorf("update auth source default settings: %w", err) + } + return nil +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -931,27 +1654,96 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { return fmt.Errorf("check existing settings: %w", err) } + oidcUsePKCEDefault := true + oidcValidateIDTokenDefault := true + if s != nil && s.cfg != nil { + if s.cfg.OIDC.UsePKCEExplicit 
{ + oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE + } + if s.cfg.OIDC.ValidateIDTokenExplicit { + oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken + } + } + // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "TrafficAPI", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "TrafficAPI", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectAppID: "", + SettingKeyWeChatConnectAppSecret: "", + SettingKeyWeChatConnectOpenAppID: "", + SettingKeyWeChatConnectOpenAppSecret: "", + SettingKeyWeChatConnectMPAppID: "", + SettingKeyWeChatConnectMPAppSecret: "", + SettingKeyWeChatConnectMobileAppID: "", + SettingKeyWeChatConnectMobileAppSecret: "", + SettingKeyWeChatConnectOpenEnabled: "false", + 
SettingKeyWeChatConnectMPEnabled: "false", + SettingKeyWeChatConnectMobileEnabled: "false", + SettingKeyWeChatConnectMode: "open", + SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectRedirectURL: "", + SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyOIDCConnectClientID: "", + SettingKeyOIDCConnectClientSecret: "", + SettingKeyOIDCConnectIssuerURL: "", + SettingKeyOIDCConnectDiscoveryURL: "", + SettingKeyOIDCConnectAuthorizeURL: "", + SettingKeyOIDCConnectTokenURL: "", + SettingKeyOIDCConnectUserInfoURL: "", + SettingKeyOIDCConnectJWKSURL: "", + SettingKeyOIDCConnectScopes: "openid email profile", + SettingKeyOIDCConnectRedirectURL: "", + SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault), + SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault), + SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", + SettingKeyOIDCConnectClockSkewSeconds: "120", + SettingKeyOIDCConnectRequireEmailVerified: "false", + SettingKeyOIDCConnectUserInfoEmailPath: "", + SettingKeyOIDCConnectUserInfoIDPath: "", + SettingKeyOIDCConnectUserInfoUsernamePath: "", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultUserRPMLimit: "0", + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: 
"5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -968,12 +1760,24 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOpsQueryModeDefault: "auto", SettingKeyOpsMetricsIntervalSeconds: "60", + // Channel monitor defaults (enabled, 60s) + SettingKeyChannelMonitorEnabled: "true", + SettingKeyChannelMonitorDefaultIntervalSeconds: "60", + + // Available channels feature (default disabled; opt-in) + SettingKeyAvailableChannelsEnabled: "false", + // Claude Code version check (default: empty = disabled) SettingKeyMinClaudeCodeVersion: "", SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -1032,6 +1836,10 @@ func (s 
*SettingService) parseSettings(settings map[string]string) *SystemSettin result.DefaultConcurrency = s.cfg.Default.UserConcurrency } + if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 { + result.DefaultUserRPMLimit = rpm + } + // 解析浮点数类型 if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil { result.DefaultBalance = balance @@ -1157,12 +1965,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { result.OIDCConnectUsePKCE = raw == "true" } else { - result.OIDCConnectUsePKCE = oidcBase.UsePKCE + result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { result.OIDCConnectValidateIDToken = raw == "true" } else { - result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken + result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) @@ -1208,6 +2016,31 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + // WeChat Connect 设置: + // - 优先读取 DB 系统设置 + // - 缺失时回退到 config/env,保持升级兼容 + weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings) + result.WeChatConnectEnabled = weChatEffective.Enabled + result.WeChatConnectAppID = weChatEffective.LegacyAppID + result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret + result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != "" + result.WeChatConnectOpenAppID = weChatEffective.OpenAppID + result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret + result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != "" + 
result.WeChatConnectMPAppID = weChatEffective.MPAppID + result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret + result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != "" + result.WeChatConnectMobileAppID = weChatEffective.MobileAppID + result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret + result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != "" + result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled + result.WeChatConnectMPEnabled = weChatEffective.MPEnabled + result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled + result.WeChatConnectMode = weChatEffective.Mode + result.WeChatConnectScopes = weChatEffective.Scopes + result.WeChatConnectRedirectURL = weChatEffective.RedirectURL + result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -1240,6 +2073,15 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } } + // Channel monitor feature (default: enabled, 60s) + result.ChannelMonitorEnabled = !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]) + result.ChannelMonitorDefaultIntervalSeconds = parseChannelMonitorInterval( + settings[SettingKeyChannelMonitorDefaultIntervalSeconds], + ) + + // Available channels feature (default: disabled; strict true) + result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true" + // Claude Code version check result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] @@ -1263,6 +2105,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.WebSearchEmulationEnabled = wsCfg.Enabled && 
len(wsCfg.Providers) > 0 } } + result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource]) + result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource]) + result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true" + result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" + result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" @@ -1292,6 +2139,23 @@ func isFalseSettingValue(value string) bool { } } +func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) { + _ = enabled + source = strings.TrimSpace(source) + if source == "" { + return "", nil + } + + normalized := NormalizeVisibleMethodSource(method, source) + if normalized == "" { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return normalized, nil +} + func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { raw = strings.TrimSpace(raw) if raw == "" { @@ -1317,6 +2181,73 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: false, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := 
strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + +func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: globalDefaults.Balance, + Concurrency: globalDefaults.Concurrency, + Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...), + GrantOnSignup: providerDefaults.GrantOnSignup, + GrantOnFirstBind: providerDefaults.GrantOnFirstBind, + } + + if providerDefaults.Balance != defaultAuthSourceBalance { + result.Balance = providerDefaults.Balance + } + if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { + result.Concurrency = providerDefaults.Concurrency + } + if len(providerDefaults.Subscriptions) > 0 { + result.Subscriptions = 
append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...) + } + + return result +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { @@ -1539,7 +2470,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } - if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") } @@ -1587,9 +2517,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } @@ -1597,6 +2524,35 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。 +// +// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。 +func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) { + keys := []string{ + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectOpenAppID, + SettingKeyWeChatConnectOpenAppSecret, + SettingKeyWeChatConnectMPAppID, + SettingKeyWeChatConnectMPAppSecret, + SettingKeyWeChatConnectMobileAppID, + SettingKeyWeChatConnectMobileAppSecret, + SettingKeyWeChatConnectOpenEnabled, + 
SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err) + } + return s.parseWeChatConnectOAuthConfig(settings) +} + // GetOverloadCooldownSettings 获取529过载冷却配置 func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) @@ -1733,9 +2689,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. } if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { effective.UsePKCE = raw == "true" + } else { + effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective) } if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" + } else { + effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective) } if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) @@ -1864,9 +2824,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. 
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff4974066c5c18589235e20194cbcb39f9a3153 --- /dev/null +++ b/backend/internal/service/setting_service_auth_source_defaults_test.go @@ -0,0 +1,138 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type authSourceDefaultsRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for key, value := range settings { + s.updates[key] = 
value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) { + repo := &authSourceDefaultsRepoStub{ + values: map[string]string{ + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true", + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetAuthSourceDefaultSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 12.5, got.Email.Balance) + require.Equal(t, 7, got.Email.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions) + require.False(t, got.Email.GrantOnSignup) + require.False(t, got.Email.GrantOnFirstBind) + require.Equal(t, 0.0, got.LinuxDo.Balance) + require.Equal(t, 5, got.LinuxDo.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions) + require.False(t, got.LinuxDo.GrantOnSignup) + require.True(t, got.LinuxDo.GrantOnFirstBind) + require.Equal(t, 5, got.OIDC.Concurrency) + require.Equal(t, 5, got.WeChat.Concurrency) + require.False(t, got.OIDC.GrantOnSignup) + require.False(t, got.WeChat.GrantOnSignup) + require.True(t, got.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) { + repo := &authSourceDefaultsRepoStub{} + svc := NewSettingService(repo, 
&config.Config{}) + + err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + Balance: 1.25, + Concurrency: 3, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + LinuxDo: ProviderDefaultGrantSettings{ + Balance: 2, + Concurrency: 4, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}}, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + OIDC: ProviderDefaultGrantSettings{ + Balance: 3, + Concurrency: 5, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}}, + GrantOnSignup: true, + GrantOnFirstBind: true, + }, + WeChat: ProviderDefaultGrantSettings{ + Balance: 4, + Concurrency: 6, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, + GrantOnSignup: false, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: true, + }) + require.NoError(t, err) + require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup]) + require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind]) + require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup]) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got)) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got) +} diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go index 3809b332bd107e2c6559a9b657014a8f6db80b77..6132420431255b51e7d86d4eea79dcecffad7570 100644 --- a/backend/internal/service/setting_service_oidc_config_test.go +++ 
b/backend/internal/service/setting_service_oidc_config_test.go @@ -101,3 +101,151 @@ func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testi require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL) require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL) } + +func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{}) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + SettingKeyOIDCConnectUsePKCE: "false", + SettingKeyOIDCConnectValidateIDToken: "false", + }) + + require.False(t, got.OIDCConnectUsePKCE) + require.False(t, got.OIDCConnectValidateIDToken) +} + +func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ + OIDC: config.OIDCConnectConfig{ + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, + }, + }) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }) + + require.True(t, got.OIDCConnectUsePKCE) + require.True(t, got.OIDCConnectValidateIDToken) +} + +func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) { + svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ + OIDC: config.OIDCConnectConfig{ + UsePKCE: true, + ValidateIDToken: true, + }, + }) + + got := svc.parseSettings(map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }) + + require.True(t, got.OIDCConnectUsePKCE) + require.True(t, got.OIDCConnectValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ 
+ Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + SettingKeyOIDCConnectUsePKCE: "false", + SettingKeyOIDCConnectValidateIDToken: "false", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.False(t, got.UsePKCE) + require.False(t, got.ValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + UsePKCEExplicit: true, + ValidateIDToken: true, + ValidateIDTokenExplicit: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + 
require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) +} + +func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) { + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: "https://issuer.example.com", + AuthorizeURL: "https://issuer.example.com/auth", + TokenURL: "https://issuer.example.com/token", + UserInfoURL: "https://issuer.example.com/userinfo", + JWKSURL: "https://issuer.example.com/jwks", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{ + SettingKeyOIDCConnectEnabled: "true", + }} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.UsePKCE) + require.True(t, got.ValidateIDToken) +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 5cf1e860eeab69786a119d9745b95b6cff59c214..1ecd4e6f416482b87e4fb7e0535079ee04e167f3 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -77,3 +77,77 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) require.Equal(t, 50, settings.TableDefaultPageSize) require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) } + +func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := 
NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectOpenEnabled: "true", + SettingKeyWeChatConnectMPEnabled: "true", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.WeChatOAuthEnabled) + require.True(t, settings.WeChatOAuthOpenEnabled) + require.True(t, settings.WeChatOAuthMPEnabled) +} + +func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAuthAvailable(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectMobileEnabled: "true", + SettingKeyWeChatConnectMode: "mobile", + SettingKeyWeChatConnectMobileAppID: "wx-mobile-app", + SettingKeyWeChatConnectMobileAppSecret: "wx-mobile-secret", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.WeChatOAuthEnabled) + require.False(t, settings.WeChatOAuthOpenEnabled) + require.False(t, settings.WeChatOAuthMPEnabled) + require.True(t, settings.WeChatOAuthMobileEnabled) +} + +func 
TestSettingService_GetPublicSettings_FallsBackToConfigForWeChatOAuthCapabilities(t *testing.T) { + svc := NewSettingService(&settingPublicRepoStub{values: map[string]string{}}, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.WeChatOAuthEnabled) + require.True(t, settings.WeChatOAuthOpenEnabled) + require.False(t, settings.WeChatOAuthMPEnabled) + require.False(t, settings.WeChatOAuthMobileEnabled) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index e62218b4543d22487078bd2119198ece8bd7e5cc..9dc0ca59a3ff1d1b5bf99e9905bed1b576e35826 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize]) require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions]) } + +func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "alipay", + PaymentVisibleMethodWxpaySource: "easypay", + PaymentVisibleMethodAlipayEnabled: true, + PaymentVisibleMethodWxpayEnabled: false, + OpenAIAdvancedSchedulerEnabled: true, + }) + require.NoError(t, err) + require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, VisibleMethodSourceEasyPayWechat, 
repo.updates[SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey]) +} + +func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "not-a-provider", + }) + require.Error(t, err) + require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a2de614b1e876f620d7403dbb9b8bc74467b5ba0 --- /dev/null +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -0,0 +1,162 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingWeChatRepoStub struct { + values map[string]string +} + +func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) { + if value, ok := s.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (s *settingWeChatRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = 
value + } + } + return out, nil +} + +func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingWeChatRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-db-app", + SettingKeyWeChatConnectAppSecret: "wx-db-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectOpenEnabled: "true", + SettingKeyWeChatConnectMPEnabled: "true", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.Equal(t, "wx-db-app", got.AppIDForMode("mp")) + require.Equal(t, "wx-db-secret", got.AppSecretForMode("mp")) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "mp", got.Mode) + require.Equal(t, "snsapi_base", got.Scopes) + require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL) + require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL) +} + +func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabaseEmpty(t *testing.T) { + repo := &settingWeChatRepoStub{values: map[string]string{}} + svc := NewSettingService(repo, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + MPEnabled: true, + Mode: "open", + 
OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + MPAppID: "wx-mp-config", + MPAppSecret: "wx-mp-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "wx-open-config", got.AppIDForMode("open")) + require.Equal(t, "wx-open-secret", got.AppSecretForMode("open")) + require.Equal(t, "wx-mp-config", got.AppIDForMode("mp")) + require.Equal(t, "wx-mp-secret", got.AppSecretForMode("mp")) + require.Equal(t, "/auth/wechat/config-callback", got.FrontendRedirectURL) + require.Empty(t, got.RedirectURL) +} + +func TestSettingService_GetWeChatConnectOAuthConfig_IgnoresSyntheticDisabledCapabilitiesFromMigration118(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectOpenEnabled: "false", + SettingKeyWeChatConnectMPEnabled: "false", + }, + } + svc := NewSettingService(repo, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + MPEnabled: true, + Mode: "open", + OpenAppID: "wx-open-config", + OpenAppSecret: "wx-open-secret", + MPAppID: "wx-mp-config", + MPAppSecret: "wx-mp-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.True(t, got.OpenEnabled) + require.True(t, got.MPEnabled) + require.Equal(t, "wx-open-config", got.AppIDForMode("open")) + require.Equal(t, "wx-mp-config", got.AppIDForMode("mp")) +} + +func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) { + svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{ + WeChat: config.WeChatConnectConfig{ + Enabled: true, + OpenEnabled: true, + Mode: "open", + OpenAppID: 
"wx-open-config", + OpenAppSecret: "wx-open-secret", + FrontendRedirectURL: "/auth/wechat/config-callback", + }, + }) + + got := svc.parseSettings(map[string]string{}) + require.True(t, got.WeChatConnectEnabled) + require.True(t, got.WeChatConnectOpenEnabled) + require.Equal(t, "wx-open-config", got.WeChatConnectOpenAppID) + require.True(t, got.WeChatConnectOpenAppSecretConfigured) + require.Equal(t, "/auth/wechat/config-callback", got.WeChatConnectFrontendRedirectURL) + require.Equal(t, "open", got.WeChatConnectMode) + require.Equal(t, "snsapi_login", got.WeChatConnectScopes) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ab2eb274fd95fbcb66291c6a1d3be1c13670a2a1..ddd4fff66628cd9f17092343a11066ac02b8c952 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,5 +1,16 @@ package service +import "strings" + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + type SystemSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool @@ -31,6 +42,28 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool + WeChatConnectAppID string + WeChatConnectAppSecret string + WeChatConnectAppSecretConfigured bool + WeChatConnectOpenAppID string + WeChatConnectOpenAppSecret string + WeChatConnectOpenAppSecretConfigured bool + WeChatConnectMPAppID string + WeChatConnectMPAppSecret string + WeChatConnectMPAppSecretConfigured bool + WeChatConnectMobileAppID string + WeChatConnectMobileAppSecret string + WeChatConnectMobileAppSecretConfigured bool + WeChatConnectOpenEnabled bool + WeChatConnectMPEnabled bool + WeChatConnectMobileEnabled bool + WeChatConnectMode string + WeChatConnectScopes string + WeChatConnectRedirectURL string + 
WeChatConnectFrontendRedirectURL string + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool OIDCConnectProviderName string @@ -73,6 +106,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + DefaultUserRPMLimit int DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration @@ -92,6 +126,13 @@ type SystemSettings struct { OpsQueryModeDefault string OpsMetricsIntervalSeconds int + // Channel Monitor feature + ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` + + // Available Channels feature (user-facing aggregate view) + AvailableChannelsEnabled bool `json:"available_channels_enabled"` + // Claude Code version check MinClaudeCodeVersion string MaxClaudeCodeVersion string @@ -110,6 +151,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Payment visible method routing + PaymentVisibleMethodAlipaySource string + PaymentVisibleMethodWxpaySource string + PaymentVisibleMethodAlipayEnabled bool + PaymentVisibleMethodWxpayEnabled bool + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 @@ -128,6 +178,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool + ForceEmailOnThirdPartySignup bool RegistrationEmailSuffixWhitelist []string PromoCodeEnabled bool PasswordResetEnabled bool @@ -151,17 +202,88 @@ type PublicSettings struct { CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints - LinuxDoOAuthEnabled bool - BackendModeEnabled bool - PaymentEnabled bool - OIDCOAuthEnabled bool - OIDCOAuthProviderName string - Version string + LinuxDoOAuthEnabled bool + WeChatOAuthEnabled bool + WeChatOAuthOpenEnabled bool + 
WeChatOAuthMPEnabled bool + WeChatOAuthMobileEnabled bool + BackendModeEnabled bool + PaymentEnabled bool + OIDCOAuthEnabled bool + OIDCOAuthProviderName string + Version string BalanceLowNotifyEnabled bool AccountQuotaNotifyEnabled bool BalanceLowNotifyThreshold float64 BalanceLowNotifyRechargeURL string + + // Channel Monitor feature + ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` + ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` + + // Available Channels feature (user-facing aggregate view) + AvailableChannelsEnabled bool `json:"available_channels_enabled"` +} + +type WeChatConnectOAuthConfig struct { + Enabled bool + LegacyAppID string + LegacyAppSecret string + OpenAppID string + OpenAppSecret string + MPAppID string + MPAppSecret string + MobileAppID string + MobileAppSecret string + OpenEnabled bool + MPEnabled bool + MobileEnabled bool + Mode string + Scopes string + RedirectURL string + FrontendRedirectURL string +} + +func (cfg WeChatConnectOAuthConfig) SupportsMode(mode string) bool { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return cfg.MPEnabled + case "mobile": + return cfg.MobileEnabled + default: + return cfg.OpenEnabled + } +} + +func (cfg WeChatConnectOAuthConfig) ScopeForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return normalizeWeChatConnectScopeSetting(cfg.Scopes, "mp") + case "mobile": + return "" + } + return defaultWeChatConnectScopeForMode("open") +} + +func (cfg WeChatConnectOAuthConfig) AppIDForMode(mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + return strings.TrimSpace(firstNonEmpty(cfg.MPAppID, cfg.LegacyAppID)) + case "mobile": + return strings.TrimSpace(firstNonEmpty(cfg.MobileAppID, cfg.LegacyAppID)) + } + return strings.TrimSpace(firstNonEmpty(cfg.OpenAppID, cfg.LegacyAppID)) +} + +func (cfg WeChatConnectOAuthConfig) AppSecretForMode(mode string) string { + switch 
normalizeWeChatConnectModeSetting(mode) { + case "mp": + return strings.TrimSpace(firstNonEmpty(cfg.MPAppSecret, cfg.LegacyAppSecret)) + case "mobile": + return strings.TrimSpace(firstNonEmpty(cfg.MobileAppSecret, cfg.LegacyAppSecret)) + } + return strings.TrimSpace(firstNonEmpty(cfg.OpenAppSecret, cfg.LegacyAppSecret)) } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/backend/internal/service/sql_errors.go b/backend/internal/service/sql_errors.go new file mode 100644 index 0000000000000000000000000000000000000000..7c0155a4e68044742ddb1084f570f9dde33484bc --- /dev/null +++ b/backend/internal/service/sql_errors.go @@ -0,0 +1,14 @@ +package service + +import ( + "database/sql" + "errors" + "strings" +) + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") +} diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index e7ef89826c399e5f9d2b3d41d491f0deb69ae462..11ace7bd8e0e2f1315e7d42ffbed875fb1b6157f 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -15,20 +15,8 @@ import ( "github.com/stretchr/testify/require" ) -// TestShouldClearStickySession 测试粘性会话清理判断逻辑。 -// 验证在以下情况下是否正确判断需要清理粘性会话: -// - nil 账号:不清理(返回 false) -// - 状态为错误或禁用:清理 -// - 不可调度:清理 -// - 临时不可调度且未过期:清理 -// - 临时不可调度已过期:不清理 -// - 正常可调度状态:不清理 -// - 模型限流(任意时长):清理 -// -// TestShouldClearStickySession tests the sticky session clearing logic. -// Verifies correct behavior for various account states including: -// nil account, error/disabled status, unschedulable, temporary unschedulable, -// and model rate limiting scenarios. +// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation +// plus model-level rate limiting. 
func TestShouldClearStickySession(t *testing.T) { now := time.Now() future := now.Add(1 * time.Hour) @@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) { requestedModel: "claude-opus-4", // 请求不同模型 want: false, // 不同模型不受影响 }, + { + name: "apikey quota exceeded", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 10.0, + "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + }, + requestedModel: "", + want: true, + }, + { + name: "oauth quota exceeded not cleared", + account: &Account{ + Status: StatusActive, + Schedulable: true, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "quota_daily_limit": 10.0, + "quota_daily_used": 10.0, + "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339), + }, + }, + requestedModel: "", + want: false, + }, + { + name: "overloaded account", + account: &Account{ + Status: StatusActive, + Schedulable: true, + OverloadUntil: &future, + }, + requestedModel: "", + want: true, + }, + { + name: "account-level rate limited", + account: &Account{ + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + }, + requestedModel: "", + want: true, + }, } for _, tt := range tests { diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go index 6a0da5f96011090989a5ce8c3bb76ebb8aa3ff9a..2e57c0ed8cd301d61d7a5362aa48bb8da83247b6 100644 --- a/backend/internal/service/totp_service.go +++ b/backend/internal/service/totp_service.go @@ -58,9 +58,15 @@ type TotpSetupSession struct { // TotpLoginSession represents a pending 2FA login session type TotpLoginSession struct { - UserID int64 - Email string - TokenExpiry time.Time + UserID int64 + Email string + TokenExpiry time.Time + PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"` +} + +type PendingOAuthBindLoginSession struct { + PendingSessionToken 
string `json:"pending_session_token,omitempty"` + BrowserSessionKey string `json:"browser_session_key,omitempty"` } // TotpStatus represents the TOTP status for a user @@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) // CreateLoginSession creates a temporary login session for 2FA func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) { + return s.createLoginSession(ctx, userID, email, nil) +} + +// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will +// finalize a pending OAuth bind after the TOTP code is verified. +func (s *TotpService) CreatePendingOAuthBindLoginSession( + ctx context.Context, + userID int64, + email string, + pendingSessionToken string, + browserSessionKey string, +) (string, error) { + return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{ + PendingSessionToken: pendingSessionToken, + BrowserSessionKey: browserSessionKey, + }) +} + +func (s *TotpService) createLoginSession( + ctx context.Context, + userID int64, + email string, + pendingOAuthBind *PendingOAuthBindLoginSession, +) (string, error) { // Generate a random temp token tempToken, err := generateRandomToken(32) if err != nil { @@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai } session := &TotpLoginSession{ - UserID: userID, - Email: email, - TokenExpiry: time.Now().Add(totpLoginTTL), + UserID: userID, + Email: email, + TokenExpiry: time.Now().Add(totpLoginTTL), + PendingOAuthBind: pendingOAuthBind, } if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil { diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index a0444d52092c9d6d529b49b0388344797dd130f3..ddf0e8185241ddf5597dd348cf64bf1a0d275980 100644 --- a/backend/internal/service/upstream_response_limit.go +++ 
b/backend/internal/service/upstream_response_limit.go @@ -12,7 +12,9 @@ import ( var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") -const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 +// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes, +// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。 +const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 59f8aa6b78eb4433a877f6b5490c0b51d28de9f4..f98336111bacca30f74b283e7d39991f0e20b494 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -7,19 +7,31 @@ import ( ) type User struct { - ID int64 - Email string - Username string - Notes string - PasswordHash string - Role string - Balance float64 - Concurrency int - Status string - AllowedGroups []int64 - TokenVersion int64 // Incremented on password change to invalidate existing tokens - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Email string + Username string + Notes string + AvatarURL string + AvatarSource string + AvatarMIME string + AvatarByteSize int + AvatarSHA256 string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + TokenVersion int64 // Incremented on password change to invalidate existing tokens + // TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived + // value expected in JWT claims and refresh-token state. 
+ TokenVersionResolved bool + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier @@ -37,6 +49,15 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry TotalRecharged float64 + // RPMLimit 用户级每分钟请求数上限(0 = 不限制)。仅在所用分组未设置 rpm_limit + // 且该 (用户, 分组) 无 rpm_override 时作为全局兜底生效,计数键 rpm:u:{userID}:{min}。 + RPMLimit int + + // UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。 + // nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用, + // 避免每请求查 DB。字段不持久化到数据库。 + UserGroupRPMOverride *int + APIKeys []APIKey Subscriptions []UserSubscription } diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go index 3d221a25e52b17c0f26903cad601958ef7f0513d..f069eb7e41504580894ec0246bcf490f5b7a0fa1 100644 --- a/backend/internal/service/user_group_rate.go +++ b/backend/internal/service/user_group_rate.go @@ -2,14 +2,16 @@ package service import "context" -// UserGroupRateEntry 分组下用户专属倍率条目 +// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。 +// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。 type UserGroupRateEntry struct { - UserID int64 `json:"user_id"` - UserName string `json:"user_name"` - UserEmail string `json:"user_email"` - UserNotes string `json:"user_notes"` - UserStatus string `json:"user_status"` - RateMultiplier float64 `json:"rate_multiplier"` + UserID int64 `json:"user_id"` + UserName string `json:"user_name"` + UserEmail string `json:"user_email"` + UserNotes string `json:"user_notes"` + UserStatus string `json:"user_status"` + RateMultiplier *float64 `json:"rate_multiplier,omitempty"` + RPMOverride *int `json:"rpm_override,omitempty"` } // GroupRateMultiplierInput 批量设置分组倍率的输入条目 @@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct { RateMultiplier float64 `json:"rate_multiplier"` } -// UserGroupRateRepository 用户专属分组倍率仓储接口 -// 
允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。 +// RPMOverride 为 *int 以支持清除(nil)语义。 +type GroupRPMOverrideInput struct { + UserID int64 `json:"user_id"` + RPMOverride *int `json:"rpm_override"` +} + +// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。 +// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。 type UserGroupRateRepository interface { - // GetByUserID 获取用户的所有专属分组倍率 - // 返回 map[groupID]rateMultiplier + // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) - // GetByUserAndGroup 获取用户在特定分组的专属倍率 - // 如果未设置专属倍率,返回 nil + // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) - // GetByGroupID 获取指定分组下所有用户的专属倍率 + // GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) + GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) + + // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) - // SyncUserGroupRates 同步用户的分组专属倍率 - // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + // SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error - // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据) + // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error - // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + // SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。 + // 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。 + SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error + + // ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL) + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + + // 
DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用) DeleteByGroupID(ctx context.Context, groupID int64) error - // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + // DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用) DeleteByUserID(ctx context.Context, userID int64) error } diff --git a/backend/internal/service/user_rpm_cache.go b/backend/internal/service/user_rpm_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..b88573113e05d0fd23f2690ccc8aa453f641b7f2 --- /dev/null +++ b/backend/internal/service/user_rpm_cache.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserRPMCache 用户/分组级 RPM 计数器接口。 +// +// 与账号级 RPMCache 的区别: +// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。 +// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。 +// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。 +type UserRPMCache interface { + // IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。 + // 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。 + IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // IncrementUserRPM 原子递增用户级分钟计数并返回最新值。 + // 用于用户全局 rpm_limit 兜底分支(分组未设且无 override 时)。 + IncrementUserRPM(ctx context.Context, userID int64) (count int, err error) + + // GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。 + GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // GetUserRPM 获取用户当前分钟已用 RPM(只读,不递增)。 + GetUserRPM(ctx context.Context, userID int64) (count int, err error) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3490e8042148579c5673f2cbd14092bea454c6fe..a7279e6a7d7be587667686c154454e3501063aa4 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -1,30 +1,66 @@ package service import ( + "bytes" "context" + "crypto/sha256" "crypto/subtle" + "encoding/base64" + "encoding/hex" "fmt" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "image" + "image/color" + stddraw "image/draw" + _ "image/gif" + "image/jpeg" + _ "image/png" "log/slog" + "net/url" + "sort" + "strconv" "strings" + "sync" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + xdraw "golang.org/x/image/draw" + "golang.org/x/sync/singleflight" ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") - ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") + ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") + ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") + ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid") + ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid") + ErrIdentityUnbindLastMethod = infraerrors.Conflict( + "IDENTITY_UNBIND_LAST_METHOD", + "bind another sign-in method before unbinding this 
provider", + ) ) const ( - maxNotifyEmails = 3 // Maximum number of notification emails per user + maxNotifyEmails = 3 // Maximum number of notification emails per user + maxInlineAvatarBytes = 100 * 1024 + targetAvatarBytes = 20 * 1024 // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 notifyCodeUserRateWindow = 10 * time.Minute + + defaultUserIdentityRedirect = "/settings/profile" + userLastActiveMinTouch = 10 * time.Minute + userLastActiveFailBackoff = 30 * time.Second +) + +var ( + avatarScaleSteps = []float64{1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36} + avatarQualitySteps = []int{88, 80, 72, 64, 56, 48, 40, 32} ) // UserListFilters contains all filter options for listing users @@ -47,9 +83,15 @@ type UserRepository interface { GetFirstAdmin(ctx context.Context) (*User, error) Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error + GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) + UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + DeleteUserAvatar(ctx context.Context, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) + GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) + GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) + UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error @@ -60,6 +102,8 @@ type UserRepository interface { AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error // RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限 
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error + ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) + UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error @@ -67,15 +111,90 @@ type UserRepository interface { DisableTotp(ctx context.Context, userID int64) error } +type UserAuthIdentityRecord struct { + ProviderType string + ProviderKey string + ProviderSubject string + VerifiedAt *time.Time + Issuer *string + Metadata map[string]any + CreatedAt time.Time + UpdatedAt time.Time +} + +type UserIdentitySummary struct { + Provider string `json:"provider"` + Bound bool `json:"bound"` + BoundCount int `json:"bound_count"` + DisplayName string `json:"display_name,omitempty"` + AvatarURL string `json:"-"` + SubjectHint string `json:"subject_hint,omitempty"` + ProviderKey string `json:"provider_key,omitempty"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + BindStartPath string `json:"bind_start_path,omitempty"` + CanBind bool `json:"can_bind"` + CanUnbind bool `json:"can_unbind"` + NoteKey string `json:"note_key,omitempty"` + Note string `json:"note,omitempty"` +} + +type UserIdentitySummarySet struct { + Email UserIdentitySummary `json:"email"` + LinuxDo UserIdentitySummary `json:"linuxdo"` + OIDC UserIdentitySummary `json:"oidc"` + WeChat UserIdentitySummary `json:"wechat"` +} + +type StartUserIdentityBindingRequest struct { + Provider string + RedirectTo string +} + +type StartUserIdentityBindingResult struct { + Provider string `json:"provider"` + AuthorizeURL string `json:"authorize_url"` + Method string `json:"method"` + UseBrowserRedirect bool `json:"use_browser_redirect"` +} + +const ( + userIdentityNoteEmailManagedFromProfile = "profile.authBindings.notes.emailManagedFromProfile" + userIdentityNoteCanUnbind = "profile.authBindings.notes.canUnbind" + 
userIdentityNoteBindAnotherBeforeUnbind = "profile.authBindings.notes.bindAnotherBeforeUnbind" +) + // UpdateProfileRequest 更新用户资料请求 type UpdateProfileRequest struct { Email *string `json:"email"` Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` Concurrency *int `json:"concurrency"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type UserAvatar struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type UpsertUserAvatarInput struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type userProfileIdentityTxRunner interface { + WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -88,6 +207,8 @@ type UserService struct { settingRepo SettingRepository authCacheInvalidator APIKeyAuthCacheInvalidator billingCache BillingCache + lastActiveTouchL1 sync.Map + lastActiveTouchSF singleflight.Group } // NewUserService 创建用户服务实例 @@ -115,14 +236,176 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + normalizeLoadedUserTokenVersion(user) + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) { + if user == nil { + var err error + user, err = s.userRepo.GetByID(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err) + } + } + + records, err := s.listUserAuthIdentities(ctx, userID) + if err != nil { + return UserIdentitySummarySet{}, err + 
} + + summaries := UserIdentitySummarySet{ + Email: s.buildEmailIdentitySummary(user, records), + LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records), + OIDC: s.buildProviderIdentitySummary("oidc", user, records), + WeChat: s.buildProviderIdentitySummary("wechat", user, records), + } + + s.applyExplicitProviderAvailability(ctx, &summaries) + return summaries, nil +} + +func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, summaries *UserIdentitySummarySet) { + if s == nil || summaries == nil || s.settingRepo == nil { + return + } + + settings, err := s.settingRepo.GetMultiple(ctx, []string{ + SettingKeyLinuxDoConnectEnabled, + SettingKeyOIDCConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectOpenEnabled, + SettingKeyWeChatConnectMPEnabled, + SettingKeyWeChatConnectMobileEnabled, + SettingKeyWeChatConnectMode, + }) + if err != nil { + return + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { + disableIdentityBindAction(&summaries.LinuxDo) + } + if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { + disableIdentityBindAction(&summaries.OIDC) + } + if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok && strings.TrimSpace(raw) != "" { + if raw != "true" { + disableIdentityBindAction(&summaries.WeChat) + return + } + openEnabled, mpEnabled, _ := parseWeChatConnectCapabilitySettings(settings, true, settings[SettingKeyWeChatConnectMode]) + if !openEnabled && !mpEnabled { + disableIdentityBindAction(&summaries.WeChat) + } + } +} + +func disableIdentityBindAction(summary *UserIdentitySummary) { + if summary == nil || summary.Bound { + return + } + summary.CanBind = false + summary.BindStartPath = "" +} + +func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) { + provider := 
normalizeUserIdentityProvider(req.Provider) + if provider == "" { + return nil, ErrIdentityProviderInvalid + } + + authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo) + if err != nil { + return nil, err + } + + return &StartUserIdentityBindingResult{ + Provider: provider, + AuthorizeURL: authorizeURL, + Method: "GET", + UseBrowserRedirect: true, + }, nil +} + +func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) { + user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider) + return user, err +} + +func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) { + provider = normalizeUserIdentityProvider(provider) + if provider == "" || provider == "email" { + return nil, false, ErrIdentityProviderInvalid + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, false, fmt.Errorf("get user: %w", err) + } + + records, err := s.listUserAuthIdentities(ctx, userID) + if err != nil { + return nil, false, err + } + if len(filterUserAuthIdentities(records, provider)) == 0 { + return user, false, nil + } + if !s.canUnbindProvider(provider, user, records) { + return nil, false, ErrIdentityUnbindLastMethod + } + + if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil { + return nil, false, err + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + + updatedUser, err := s.GetProfile(ctx, userID) + if err != nil { + return nil, false, err + } + return updatedUser, true, nil +} + // UpdateProfile 更新用户资料 func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) { + if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok { + var ( + updated *User + oldConcurrency int + ) + if err := txRunner.WithUserProfileIdentityTx(ctx, func(txCtx 
context.Context) error { + var err error + updated, oldConcurrency, err = s.updateProfile(txCtx, userID, req) + return err + }); err != nil { + return nil, err + } + if s.authCacheInvalidator != nil && updated != nil && updated.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return updated, nil + } + + updated, oldConcurrency, err := s.updateProfile(ctx, userID, req) + if err != nil { + return nil, err + } + if s.authCacheInvalidator != nil && updated.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + return updated, nil +} + +func (s *UserService) updateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, int, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - return nil, fmt.Errorf("get user: %w", err) + return nil, 0, fmt.Errorf("get user: %w", err) } oldConcurrency := user.Concurrency @@ -131,10 +414,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat // 检查新邮箱是否已被使用 exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email) if err != nil { - return nil, fmt.Errorf("check email exists: %w", err) + return nil, oldConcurrency, fmt.Errorf("check email exists: %w", err) } if exists && *req.Email != user.Email { - return nil, ErrEmailExists + return nil, oldConcurrency, ErrEmailExists } user.Email = *req.Email } @@ -143,6 +426,14 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.Username = *req.Username } + if req.AvatarURL != nil { + avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL) + if err != nil { + return nil, oldConcurrency, err + } + applyUserAvatar(user, avatar) + } + if req.Concurrency != nil { user.Concurrency = *req.Concurrency } @@ -159,13 +450,465 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat } if err := s.userRepo.Update(ctx, user); err != nil { - return nil, fmt.Errorf("update user: %w", 
err) + return nil, oldConcurrency, fmt.Errorf("update user: %w", err) } - if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { - s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + + return user, oldConcurrency, nil +} + +func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) { + avatarValue := strings.TrimSpace(raw) + if avatarValue == "" { + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + return nil, nil } - return user, nil + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + return avatar, nil +} + +func applyUserAvatar(user *User, avatar *UserAvatar) { + if user == nil { + return + } + if avatar == nil { + user.AvatarURL = "" + user.AvatarSource = "" + user.AvatarMIME = "" + user.AvatarByteSize = 0 + user.AvatarSHA256 = "" + return + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 +} + +func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.HasPrefix(raw, "data:") { + return normalizeInlineUserAvatarInput(raw) + } + + parsed, err := url.Parse(raw) + if err != nil || parsed == nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.TrimSpace(parsed.Host) == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + return UpsertUserAvatarInput{ + StorageProvider: 
"remote_url", + URL: raw, + }, nil +} + +func ValidateUserAvatar(raw string) error { + _, err := normalizeUserAvatarInput(raw) + return err +} + +func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + body := strings.TrimPrefix(raw, "data:") + meta, encoded, ok := strings.Cut(body, ",") + if !ok { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + meta = strings.TrimSpace(meta) + encoded = strings.TrimSpace(encoded) + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")]) + if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return UpsertUserAvatarInput{}, ErrAvatarNotImage + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if len(decoded) > maxInlineAvatarBytes { + return UpsertUserAvatarInput{}, ErrAvatarTooLarge + } + + if len(decoded) > targetAvatarBytes { + decoded, contentType, err = compressInlineAvatar(decoded) + if err != nil { + return UpsertUserAvatarInput{}, err + } + raw = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(decoded) + } + + sum := sha256.Sum256(decoded) + return UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: raw, + ContentType: contentType, + ByteSize: len(decoded), + SHA256: hex.EncodeToString(sum[:]), + }, nil +} + +func compressInlineAvatar(decoded []byte) ([]byte, string, error) { + src, _, err := image.Decode(bytes.NewReader(decoded)) + if err != nil { + return nil, "", ErrAvatarInvalid + } + + srcBounds := src.Bounds() + if srcBounds.Empty() { + return nil, "", ErrAvatarInvalid + } + + for _, scale := range avatarScaleSteps { + width := max(1, int(float64(srcBounds.Dx())*scale)) + height := max(1, int(float64(srcBounds.Dy())*scale)) + dst := image.NewRGBA(image.Rect(0, 0, width, height)) + stddraw.Draw(dst, 
dst.Bounds(), &image.Uniform{C: color.White}, image.Point{}, stddraw.Src) + xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, srcBounds, stddraw.Over, nil) + + for _, quality := range avatarQualitySteps { + var buf bytes.Buffer + if err := jpeg.Encode(&buf, dst, &jpeg.Options{Quality: quality}); err != nil { + return nil, "", ErrAvatarInvalid + } + if buf.Len() <= targetAvatarBytes { + return buf.Bytes(), "image/jpeg", nil + } + } + } + + return nil, "", ErrAvatarTooLarge +} + +func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: "email", + CanBind: false, + CanUnbind: false, + NoteKey: userIdentityNoteEmailManagedFromProfile, + Note: "Primary account email is managed from the profile form.", + } + if user == nil { + return summary + } + + filtered := filterUserAuthIdentities(records, "email") + if len(filtered) > 0 { + primary := selectPrimaryUserAuthIdentity(filtered) + email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email")) + if email == "" { + email = strings.TrimSpace(primary.ProviderSubject) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(user.Email) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(primary.ProviderKey) + } + + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + return summary + } + + // Compatibility fallback for legacy normal-email users that predate auth_identities backfill. 
+ email := strings.TrimSpace(user.Email) + if email == "" || isReservedEmail(email) { + return summary + } + summary.Bound = true + summary.BoundCount = 1 + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = "email" + return summary +} + +func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary { + summary := UserIdentitySummary{ + Provider: provider, + CanUnbind: false, + } + filtered := filterUserAuthIdentities(records, provider) + if len(filtered) == 0 { + summary.CanBind = true + bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "") + if err == nil { + summary.BindStartPath = bindStartPath + } + return summary + } + + primary := selectPrimaryUserAuthIdentity(filtered) + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = userAuthIdentityDisplayName(primary) + summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl")) + summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + summary.CanUnbind = s.canUnbindProvider(provider, user, records) + if summary.CanUnbind { + summary.NoteKey = userIdentityNoteCanUnbind + summary.Note = "You can unbind this sign-in method." + } else { + summary.NoteKey = userIdentityNoteBindAnotherBeforeUnbind + summary.Note = "Bind another sign-in method before unbinding." 
+ } + return summary +} + +func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool { + if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 { + return false + } + + if s.canUseEmailAsSignInMethod(user, records) { + return true + } + + for _, candidate := range []string{"linuxdo", "oidc", "wechat"} { + if candidate == provider { + continue + } + if len(filterUserAuthIdentities(records, candidate)) > 0 { + return true + } + } + + return false +} + +func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool { + if user == nil { + return false + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return false + } + + if emailSignupSourceAllowsLogin(user.SignupSource) { + return true + } + + for _, record := range filterUserAuthIdentities(records, "email") { + if emailIdentitySupportsSignIn(record) { + return true + } + } + + return false +} + +func emailSignupSourceAllowsLogin(signupSource string) bool { + signupSource = strings.ToLower(strings.TrimSpace(signupSource)) + return signupSource == "" || signupSource == "email" +} + +func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool { + source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source")) + switch source { + case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write": + return true + default: + return false + } +} + +func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + if userID <= 0 || s == nil || s.userRepo == nil { + return nil, nil + } + return s.userRepo.ListUserAuthIdentities(ctx, userID) +} + +func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) { + provider = normalizeUserIdentityProvider(provider) + if provider == "" || provider == "email" { + return "", 
ErrIdentityProviderInvalid + } + + redirectTo, err := normalizeUserIdentityRedirect(redirectTo) + if err != nil { + return "", err + } + + path := "" + switch provider { + case "linuxdo": + path = "/api/v1/auth/oauth/linuxdo/bind/start" + case "oidc": + path = "/api/v1/auth/oauth/oidc/bind/start" + case "wechat": + path = "/api/v1/auth/oauth/wechat/bind/start" + default: + return "", ErrIdentityProviderInvalid + } + + query := url.Values{} + query.Set("redirect", redirectTo) + query.Set("intent", "bind_current_user") + return path + "?" + query.Encode(), nil +} + +func normalizeUserIdentityProvider(provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + case "email": + return "email" + default: + return "" + } +} + +func normalizeUserIdentityRedirect(raw string) (string, error) { + redirect := strings.TrimSpace(raw) + if redirect == "" { + return defaultUserIdentityRedirect, nil + } + if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { + return "", ErrIdentityRedirectInvalid + } + return redirect, nil +} + +func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord { + if len(records) == 0 { + return nil + } + filtered := make([]UserAuthIdentityRecord, 0, len(records)) + for _, record := range records { + if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) { + filtered = append(filtered, record) + } + } + return filtered +} + +func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord { + if len(records) == 0 { + return UserAuthIdentityRecord{} + } + sort.SliceStable(records, func(i, j int) bool { + left := userAuthIdentitySortTime(records[i]) + right := userAuthIdentitySortTime(records[j]) + if !left.Equal(right) { + return left.After(right) + } + return records[i].ProviderKey < 
records[j].ProviderKey + }) + return records[0] +} + +func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time { + if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() { + return record.VerifiedAt.UTC() + } + if !record.UpdatedAt.IsZero() { + return record.UpdatedAt.UTC() + } + if !record.CreatedAt.IsZero() { + return record.CreatedAt.UTC() + } + return time.Time{} +} + +func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string { + if displayName := firstStringIdentityValue(record.Metadata, + "display_name", + "suggested_display_name", + "username", + "name", + "nickname", + "email", + ); displayName != "" { + return displayName + } + if subject := strings.TrimSpace(record.ProviderSubject); subject != "" { + return subject + } + return strings.TrimSpace(record.ProviderType) +} + +func firstStringIdentityValue(values map[string]any, keys ...string) string { + for _, key := range keys { + raw, ok := values[key] + if !ok { + continue + } + switch value := raw.(type) { + case string: + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + case fmt.Stringer: + if trimmed := strings.TrimSpace(value.String()); trimmed != "" { + return trimmed + } + } + } + return "" +} + +func maskEmailIdentity(email string) string { + local, domain, ok := strings.Cut(strings.TrimSpace(email), "@") + if !ok || local == "" || domain == "" { + return maskOpaqueIdentity(email) + } + runes := []rune(local) + if len(runes) == 1 { + return string(runes[0]) + "***@" + domain + } + return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain +} + +func maskOpaqueIdentity(value string) string { + value = strings.TrimSpace(value) + runes := []rune(value) + switch { + case len(runes) == 0: + return "" + case len(runes) <= 4: + return string(runes[0]) + "***" + case len(runes) <= 8: + return string(runes[:2]) + "***" + string(runes[len(runes)-1:]) + default: + return string(runes[:3]) + "***" + string(runes[len(runes)-3:]) + } 
} // ChangePassword 修改密码 @@ -202,9 +945,94 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + normalizeLoadedUserTokenVersion(user) + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func normalizeLoadedUserTokenVersion(user *User) { + if user == nil || user.TokenVersionResolved { + return + } + user.TokenVersion = resolvedTokenVersion(user) + user.TokenVersionResolved = true +} + +// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。 +// 该操作为尽力而为,不应中断正常请求。 +func (s *UserService) TouchLastActive(ctx context.Context, userID int64) { + if s == nil || s.userRepo == nil || userID <= 0 { + return + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err) + return + } + s.TouchLastActiveForUser(ctx, user) +} + +// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。 +func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) { + if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 { + return + } + + now := time.Now() + if userLastActiveFresh(user.LastActiveAt, now) { + return + } + if v, ok := s.lastActiveTouchL1.Load(user.ID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return + } + } + + _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastActiveTouchL1.Load(user.ID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + if userLastActiveFresh(user.LastActiveAt, latest) { + return nil, nil + } + if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil { + s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff)) + return nil, 
fmt.Errorf("touch user last active: %w", err) + } + s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch)) + return nil, nil + }) + if err != nil { + slog.Warn("touch user last active failed", "user_id", user.ID, "error", err) + } +} + +func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool { + if lastActiveAt == nil { + return false + } + return now.Before(lastActiveAt.Add(userLastActiveMinTouch)) +} + +func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error { + if s == nil || s.userRepo == nil || user == nil || user.ID == 0 { + return nil + } + + avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + applyUserAvatar(user, avatar) + return nil +} + // List 获取用户列表(管理员功能) func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go new file mode 100644 index 0000000000000000000000000000000000000000..702b3b1a21503ecdd32f96ad87919788cb14f07b --- /dev/null +++ b/backend/internal/service/user_service_email_identity_sync_test.go @@ -0,0 +1,34 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + user: &User{ + ID: 19, + Email: "profile-before@example.com", + Username: "tester", + Concurrency: 2, + }, + replaceErr: context.DeadlineExceeded, + } + svc := NewUserService(repo, nil, nil, nil) + + newEmail := "profile-after@example.com" + updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{ + Email: &newEmail, + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, newEmail, 
updated.Email) + require.Equal(t, 1, repo.updateCalls) + require.Empty(t, repo.replaceCalls) + require.Empty(t, repo.ensureCalls) +} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index a998d5f443d06d3a2238ad6923b0d4afd74a5579..ff55c2a50abe2a2a83dc271a1a3898157c752839 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -3,8 +3,14 @@ package service import ( + "bytes" "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "errors" + "image" + "image/png" "sync" "sync/atomic" "testing" @@ -17,16 +23,159 @@ import ( // --- mock: UserRepository --- type mockUserRepo struct { - updateBalanceErr error - updateBalanceFn func(ctx context.Context, id int64, amount float64) error + updateBalanceErr error + updateBalanceFn func(ctx context.Context, id int64, amount float64) error + getByIDUser *User + getByIDErr error + identities []UserAuthIdentityRecord + unbindIdentityErr error + unboundProviders []string + updateLastActiveErr error + updateLastActiveUserIDs []int64 + updateLastActiveAt []time.Time + updateFn func(ctx context.Context, user *User) error + updateCalls int + upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarFn func(ctx context.Context, userID int64) error + deleteAvatarIDs []int64 + getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error) + txCalls int } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +type mockUserRepoTxKey struct{} + +type mockUserRepoTxState struct { + getByIDUser *User + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarIDs []int64 +} + +type mockUserSettingRepo struct { + values map[string]string +} + +func (m *mockUserSettingRepo) Get(context.Context, 
string) (*Setting, error) { + panic("unexpected Get call") +} + +func (m *mockUserSettingRepo) GetValue(context.Context, string) (string, error) { + panic("unexpected GetValue call") +} + +func (m *mockUserSettingRepo) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (m *mockUserSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := m.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (m *mockUserSettingRepo) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (m *mockUserSettingRepo) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (m *mockUserSettingRepo) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) { + if m.getByIDErr != nil { + return nil, m.getByIDErr + } + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil && txState.getByIDUser != nil { + cloned := *txState.getByIDUser + return &cloned, nil + } + if m.getByIDUser != nil { + cloned := *m.getByIDUser + return &cloned, nil + } + return &User{}, nil +} func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Update(ctx context.Context, user *User) error { + m.updateCalls++ + if m.updateFn != nil { + return m.updateFn(ctx, user) + } + return nil +} +func (m *mockUserRepo) Delete(context.Context, int64) error 
{ return nil } +func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + if m.getAvatarFn != nil { + return m.getAvatarFn(ctx, userID) + } + return nil, nil +} +func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil { + txState.upsertAvatarArgs = append(txState.upsertAvatarArgs, input) + if txState.getByIDUser != nil { + txState.getByIDUser.AvatarURL = input.URL + txState.getByIDUser.AvatarSource = input.StorageProvider + txState.getByIDUser.AvatarMIME = input.ContentType + txState.getByIDUser.AvatarByteSize = input.ByteSize + txState.getByIDUser.AvatarSHA256 = input.SHA256 + } + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil + } + m.upsertAvatarArgs = append(m.upsertAvatarArgs, input) + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil { + txState.deleteAvatarIDs = append(txState.deleteAvatarIDs, userID) + if txState.getByIDUser != nil { + txState.getByIDUser.AvatarURL = "" + txState.getByIDUser.AvatarSource = "" + txState.getByIDUser.AvatarMIME = "" + txState.getByIDUser.AvatarByteSize = 0 + txState.getByIDUser.AvatarSHA256 = "" + } + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil + } + 
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID) + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil +} func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -39,6 +188,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float } return m.updateBalanceErr } +func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error { + m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID) + m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt) + return m.updateLastActiveErr +} func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil } func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil } func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } @@ -46,12 +200,58 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + out := make([]UserAuthIdentityRecord, len(m.identities)) + copy(out, m.identities) + return out, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) 
error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } +func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error { + if m.unbindIdentityErr != nil { + return m.unbindIdentityErr + } + m.unboundProviders = append(m.unboundProviders, provider) + filtered := m.identities[:0] + for _, identity := range m.identities { + if identity.ProviderType == provider { + continue + } + filtered = append(filtered, identity) + } + m.identities = append([]UserAuthIdentityRecord(nil), filtered...) + return nil +} + +func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + m.txCalls++ + txState := &mockUserRepoTxState{ + upsertAvatarArgs: append([]UpsertUserAvatarInput(nil), m.upsertAvatarArgs...), + deleteAvatarIDs: append([]int64(nil), m.deleteAvatarIDs...), + } + if m.getByIDUser != nil { + userCopy := *m.getByIDUser + txState.getByIDUser = &userCopy + } + err := fn(context.WithValue(ctx, mockUserRepoTxKey{}, txState)) + if err != nil { + return err + } + m.getByIDUser = txState.getByIDUser + m.upsertAvatarArgs = txState.upsertAvatarArgs + m.deleteAvatarIDs = txState.deleteAvatarIDs + return nil +} // --- mock: APIKeyAuthCacheInvalidator --- @@ -132,6 +332,225 @@ func TestUpdateBalance_Success(t *testing.T) { require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存") } +func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: 
"linuxdo", + ProviderSubject: "linuxdo-subject-123456", + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser) + + require.NoError(t, err) + require.True(t, summaries.LinuxDo.Bound) + require.True(t, summaries.LinuxDo.CanUnbind) + require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName) + require.NotEmpty(t, summaries.LinuxDo.SubjectHint) +} + +func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "only-user@linuxdo-connect.invalid", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo") + + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + +func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "oauth-only@example.com", + SignupSource: "oidc", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "oidc-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser) + + require.NoError(t, err) + require.False(t, summaries.OIDC.CanUnbind) + + _, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc") + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + +func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) { 
+ repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 11, + Email: "oauth-only@example.com", + SignupSource: "wechat", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "oauth-only@example.com", + Metadata: map[string]any{ + "backfill_source": "users.email", + "migration": "109_auth_identity_compat_backfill", + }, + }, + { + ProviderType: "wechat", + ProviderKey: "wechat", + ProviderSubject: "wechat-only-subject", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser) + + require.NoError(t, err) + require.True(t, summaries.Email.Bound) + require.False(t, summaries.WeChat.CanUnbind) + + _, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat") + require.ErrorIs(t, err, ErrIdentityUnbindLastMethod) + require.Empty(t, repo.unboundProviders) +} + +func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-12", + }, + }, + } + invalidator := &mockAuthCacheInvalidator{} + svc := NewUserService(repo, nil, invalidator, nil) + + user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo") + + require.NoError(t, err) + require.Equal(t, []string{"linuxdo"}, repo.unboundProviders) + require.Equal(t, int64(12), user.ID) + require.Equal(t, []int64{12}, invalidator.invalidatedUserIDs) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user) + require.NoError(t, err) + require.False(t, summaries.LinuxDo.Bound) + require.True(t, summaries.LinuxDo.CanBind) +} + +func 
TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabled(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 15, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + }, + } + settingRepo := &mockUserSettingRepo{ + values: map[string]string{ + SettingKeyLinuxDoConnectEnabled: "false", + }, + } + svc := NewUserService(repo, settingRepo, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 15, repo.getByIDUser) + + require.NoError(t, err) + require.False(t, summaries.LinuxDo.Bound) + require.False(t, summaries.LinuxDo.CanBind) + require.Empty(t, summaries.LinuxDo.BindStartPath) +} + +func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 16, + Email: "alice@example.com", + }, + identities: []UserAuthIdentityRecord{ + { + ProviderType: "email", + ProviderKey: "email", + ProviderSubject: "alice@example.com", + }, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser) + + require.NoError(t, err) + require.Equal( + t, + "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.LinuxDo.BindStartPath, + ) + require.Equal( + t, + "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.OIDC.BindStartPath, + ) + require.Equal( + t, + "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile", + summaries.WeChat.BindStartPath, + ) +} + func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { repo := &mockUserRepo{} svc := NewUserService(repo, nil, nil, nil) // billingCache = nil @@ -154,6 +573,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { }, 
2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance") } +func TestTouchLastActive_UpdatesWhenStale(t *testing.T) { + stale := time.Now().Add(-11 * time.Minute) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 42, + LastActiveAt: &stale, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + svc.TouchLastActive(context.Background(), 42) + + require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs) + require.Len(t, repo.updateLastActiveAt, 1) + require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second) +} + +func TestTouchLastActive_SkipsWhenRecent(t *testing.T) { + recent := time.Now().Add(-time.Minute) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 42, + LastActiveAt: &recent, + }, + } + svc := NewUserService(repo, nil, nil, nil) + + svc.TouchLastActive(context.Background(), 42) + + require.Empty(t, repo.updateLastActiveUserIDs) + require.Empty(t, repo.updateLastActiveAt) +} + func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} cache := &mockBillingCache{} @@ -200,3 +652,199 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, cache, svc.billingCache) } + +func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) { + raw := []byte("small-avatar") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + expectedSum := sha256.Sum256(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "avatar@example.com", + Username: "avatar-user", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType) + 
require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256) + require.Equal(t, dataURL, updated.AvatarURL) + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/png", updated.AvatarMIME) + require.Equal(t, len(raw), updated.AvatarByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256) +} + +func TestUpdateProfile_CompressesInlineAvatarToTwentyKilobytes(t *testing.T) { + var encoded bytes.Buffer + for _, size := range []int{192, 224, 256, 288} { + encoded.Reset() + var img image.RGBA + img.Rect = image.Rect(0, 0, size, size) + img.Stride = size * 4 + img.Pix = make([]byte, size*size*4) + for y := 0; y < size; y++ { + for x := 0; x < size; x++ { + offset := y*img.Stride + x*4 + img.Pix[offset] = uint8((x*x + y*17) % 255) + img.Pix[offset+1] = uint8((y*y + x*29) % 255) + img.Pix[offset+2] = uint8(((x * y) + x*13 + y*7) % 255) + img.Pix[offset+3] = 0xff + } + } + require.NoError(t, png.Encode(&encoded, &img)) + if encoded.Len() > 20*1024 && encoded.Len() <= maxInlineAvatarBytes { + break + } + } + require.Greater(t, encoded.Len(), 20*1024) + require.LessOrEqual(t, encoded.Len(), maxInlineAvatarBytes) + + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(encoded.Bytes()) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 17, + Email: "avatar-compress@example.com", + Username: "avatar-compress", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 17, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.LessOrEqual(t, repo.upsertAvatarArgs[0].ByteSize, 20*1024) + require.Equal(t, "image/jpeg", repo.upsertAvatarArgs[0].ContentType) + require.Contains(t, repo.upsertAvatarArgs[0].URL, 
"data:image/jpeg;base64,") + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/jpeg", updated.AvatarMIME) + require.LessOrEqual(t, updated.AvatarByteSize, 20*1024) + require.Contains(t, updated.AvatarURL, "data:image/jpeg;base64,") + require.NotEmpty(t, updated.AvatarSHA256) +} + +func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) { + raw := make([]byte, maxInlineAvatarBytes+1) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 8, + Email: "large-avatar@example.com", + Username: "too-large", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.ErrorIs(t, err, ErrAvatarTooLarge) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Zero(t, repo.updateCalls) +} + +func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) { + remoteURL := "https://cdn.example.com/avatar.png" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "remote-avatar@example.com", + Username: "remote-avatar", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL) + require.Equal(t, remoteURL, updated.AvatarURL) + require.Equal(t, "remote_url", updated.AvatarSource) + require.Zero(t, updated.AvatarByteSize) +} + +func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) { + empty := "" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "delete-avatar@example.com", + Username: "delete-avatar", + AvatarURL: "https://cdn.example.com/old.png", + AvatarSource: "remote_url", + }, + 
} + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{ + AvatarURL: &empty, + }) + require.NoError(t, err) + require.Equal(t, []int64{10}, repo.deleteAvatarIDs) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, updated.AvatarURL) + require.Empty(t, updated.AvatarSource) +} + +func TestUpdateProfile_RollsBackAvatarMutationWhenUserUpdateFails(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 11, + Email: "rollback@example.com", + AvatarURL: "https://cdn.example.com/original.png", + AvatarSource: "remote_url", + }, + updateFn: func(context.Context, *User) error { + return errors.New("write user failed") + }, + } + svc := NewUserService(repo, nil, nil, nil) + + remoteURL := "https://cdn.example.com/new.png" + _, err := svc.UpdateProfile(context.Background(), 11, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + + require.EqualError(t, err, "update user: write user failed") + require.Equal(t, 1, repo.txCalls) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Equal(t, "https://cdn.example.com/original.png", repo.getByIDUser.AvatarURL) + require.Equal(t, "remote_url", repo.getByIDUser.AvatarSource) +} + +func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "profile-avatar@example.com", + Username: "profile-avatar", + }, + getAvatarFn: func(context.Context, int64) (*UserAvatar, error) { + return &UserAvatar{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/profile.png", + }, nil + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.GetProfile(context.Background(), 12) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL) + require.Equal(t, "remote_url", user.AvatarSource) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 
9f33c46ab032ac374b142ff4395375d2c2e333e0..86bfc327a5b75d5eaf597ee5f2690297e5b35165 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { return NewEmailQueueService(emailService, 3) } +// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL. +func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { + return NewOAuthRefreshAPI(accountRepo, tokenCache) +} + // ProvideTokenRefreshService creates and starts TokenRefreshService func ProvideTokenRefreshService( accountRepo AccountRepository, @@ -210,11 +215,13 @@ func ProvideRateLimitService( geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache, timeoutCounterCache TimeoutCounterCache, + openAI403CounterCache OpenAI403CounterCache, settingService *SettingService, tokenCacheInvalidator TokenCacheInvalidator, ) *RateLimitService { svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc.SetTimeoutCounterCache(timeoutCounterCache) + svc.SetOpenAI403CounterCache(openAI403CounterCache) svc.SetSettingService(settingService) svc.SetTokenCacheInvalidator(tokenCacheInvalidator) return svc @@ -262,13 +269,16 @@ func ProvideOpsAlertEvaluatorService( } // ProvideOpsCleanupService creates and starts OpsCleanupService (cron scheduled). 
+// channelMonitorSvc 让维护任务(聚合 + 历史/聚合软删)跟随 ops 清理 cron 一起跑, +// 共享 leader lock + heartbeat。 func ProvideOpsCleanupService( opsRepo OpsRepository, db *sql.DB, redisClient *redis.Client, cfg *config.Config, + channelMonitorSvc *ChannelMonitorService, ) *OpsCleanupService { - svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg) + svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg, channelMonitorSvc) svc.Start() return svc } @@ -381,6 +391,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit return svc } +// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies. +func ProvideBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + rpmCache UserRPMCache, + rateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { + return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -397,7 +420,7 @@ var ProviderSet = wire.NewSet( NewDashboardService, ProvidePricingService, NewBillingService, - NewBillingCacheService, + ProvideBillingCacheService, NewAnnouncementService, NewAdminService, NewGatewayService, @@ -409,7 +432,7 @@ var ProviderSet = wire.NewSet( NewCompositeTokenCacheInvalidator, wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, - NewOAuthRefreshAPI, + ProvideOAuthRefreshAPI, ProvideGeminiTokenProvider, NewGeminiMessagesCompatService, ProvideAntigravityTokenProvider, @@ -467,6 +490,9 @@ var ProviderSet = wire.NewSet( NewPaymentService, ProvidePaymentOrderExpiryService, ProvideBalanceNotifyService, + ProvideChannelMonitorService, + ProvideChannelMonitorRunner, + NewChannelMonitorRequestTemplateService, ) // ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named @@ -486,3 
+512,23 @@ func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderE svc.Start() return svc } + +// ProvideChannelMonitorService 创建渠道监控服务(CRUD + RunCheck + 用户视图聚合)。 +// 加密器复用 wire 中已注入的 SecretEncryptor(AES-256-GCM)。 +func ProvideChannelMonitorService( + repo ChannelMonitorRepository, + encryptor SecretEncryptor, +) *ChannelMonitorService { + return NewChannelMonitorService(repo, encryptor) +} + +// ProvideChannelMonitorRunner 创建并启动渠道监控调度器。 +// 通过 SetScheduler 注入回 service 后再 Start,确保启动时加载所有 enabled monitor, +// 后续 CRUD 也能即时同步任务表。Runner.Stop 由 cleanup function 调用。 +// settingService 用于 runner 每次 fire 读取功能开关。 +func ProvideChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner { + r := NewChannelMonitorRunner(svc, settingService) + svc.SetScheduler(r) + r.Start() + return r +} diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 89d09eeffa0a153684db6ead6a5c4ce4c84cdc77..2279d91320960aeaac751d926834092bbd89c36f 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -301,11 +301,13 @@ func shouldBypassEmbeddedFrontend(path string) bool { return strings.HasPrefix(trimmed, "/api/") || strings.HasPrefix(trimmed, "/v1/") || strings.HasPrefix(trimmed, "/v1beta/") || + strings.HasPrefix(trimmed, "/backend-api/") || strings.HasPrefix(trimmed, "/antigravity/") || strings.HasPrefix(trimmed, "/setup/") || trimmed == "/health" || trimmed == "/responses" || - strings.HasPrefix(trimmed, "/responses/") + strings.HasPrefix(trimmed, "/responses/") || + strings.HasPrefix(trimmed, "/images/") } func serveIndexHTML(c *gin.Context, fsys fs.FS) { diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index 4127a7a6e14fd5cc1719a91504f1d840b47cf0cc..583d98a0b9b1ba564190db6823d297286ed98034 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -434,6 +434,8 @@ func 
TestFrontendServer_Middleware(t *testing.T) { "/api/v1/users", "/v1/models", "/v1beta/chat", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", "/antigravity/test", "/setup/init", "/health", @@ -636,6 +638,8 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/api/users", "/v1/models", "/v1beta/chat", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", "/antigravity/test", "/setup/init", "/health", diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql new file mode 100644 index 0000000000000000000000000000000000000000..117e3ca38c5c11b00491c298344d9ada4e14650c --- /dev/null +++ b/backend/migrations/108_auth_identity_foundation_core.sql @@ -0,0 +1,141 @@ +ALTER TABLE users +ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email', +ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL; + +UPDATE users +SET signup_source = 'email' +WHERE signup_source IS NULL OR signup_source = ''; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'users_signup_source_check' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc')); + END IF; +END $$; + +CREATE TABLE IF NOT EXISTS auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + verified_at TIMESTAMPTZ NULL, + issuer TEXT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key 
+ ON auth_identities (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx + ON auth_identities (user_id); + +CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx + ON auth_identities (user_id, provider_type); + +CREATE TABLE IF NOT EXISTS auth_identity_channels ( + id BIGSERIAL PRIMARY KEY, + identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + channel VARCHAR(20) NOT NULL, + channel_app_id TEXT NOT NULL, + channel_subject TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key + ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject); + +CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx + ON auth_identity_channels (identity_id); + +CREATE TABLE IF NOT EXISTS pending_auth_sessions ( + id BIGSERIAL PRIMARY KEY, + session_token VARCHAR(255) NOT NULL, + intent VARCHAR(40) NOT NULL, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + redirect_to TEXT NOT NULL DEFAULT '', + resolved_email TEXT NOT NULL DEFAULT '', + registration_password_hash TEXT NOT NULL DEFAULT '', + upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb, + local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb, + browser_session_key TEXT NOT NULL DEFAULT '', + completion_code_hash TEXT NOT NULL DEFAULT '', + completion_code_expires_at TIMESTAMPTZ NULL, + email_verified_at TIMESTAMPTZ NULL, + password_verified_at TIMESTAMPTZ NULL, + totp_verified_at TIMESTAMPTZ NULL, + 
expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT pending_auth_sessions_intent_check + CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')), + CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key + ON pending_auth_sessions (session_token); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx + ON pending_auth_sessions (target_user_id); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx + ON pending_auth_sessions (expires_at); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx + ON pending_auth_sessions (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx + ON pending_auth_sessions (completion_code_hash); + +CREATE TABLE IF NOT EXISTS identity_adoption_decisions ( + id BIGSERIAL PRIMARY KEY, + pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE, + identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL, + adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE, + adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE, + decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key + ON identity_adoption_decisions (pending_auth_session_id); + +CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx + ON identity_adoption_decisions (identity_id); + +CREATE TABLE IF NOT EXISTS auth_identity_migration_reports ( + id BIGSERIAL PRIMARY KEY, + report_type VARCHAR(40) NOT NULL, + report_key TEXT NOT NULL, + details JSONB NOT NULL 
DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx + ON auth_identity_migration_reports (report_type); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key + ON auth_identity_migration_reports (report_type, report_key); diff --git a/backend/migrations/108a_widen_auth_identity_migration_report_type.sql b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql new file mode 100644 index 0000000000000000000000000000000000000000..bc170fb84a3b13de409772c0f080eb3fe6a05146 --- /dev/null +++ b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql @@ -0,0 +1,14 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'auth_identity_migration_reports' + AND column_name = 'report_type' + AND COALESCE(character_maximum_length, 0) < 80 + ) THEN + ALTER TABLE auth_identity_migration_reports + ALTER COLUMN report_type TYPE VARCHAR(80); + END IF; +END $$; diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql new file mode 100644 index 0000000000000000000000000000000000000000..ddbbedbccbcdd9a8673122c5a42510e66cf6197a --- /dev/null +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -0,0 +1,125 @@ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'email', + 'email', + LOWER(BTRIM(u.email)), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'users.email', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(u.email, '')) <> '' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), 
LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'linuxdo', + 'linuxdo', + SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'wechat', + 'wechat', + SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +UPDATE users +SET signup_source = 'linuxdo' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'; + +UPDATE users +SET signup_source = 'wechat' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$'; + +UPDATE users +SET signup_source = 'oidc' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$'; + +INSERT INTO auth_identity_migration_reports 
(report_type, report_key, details) +SELECT + 'oidc_synthetic_email_requires_manual_recovery', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities ai + WHERE ai.user_id = u.id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + ) +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql new file mode 100644 index 0000000000000000000000000000000000000000..f59b2188c564acadf732eb9bfd5bd731ed65100e --- /dev/null +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -0,0 +1,59 @@ +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind', + granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT user_provider_default_grants_provider_type_check + CHECK 
(provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')), + CONSTRAINT user_provider_default_grants_reason_check + CHECK (grant_reason IN ('signup', 'first_bind')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key + ON user_provider_default_grants (user_id, provider_type, grant_reason); + +CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx + ON user_provider_default_grants (user_id); + +CREATE TABLE IF NOT EXISTS user_avatars ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_provider VARCHAR(20) NOT NULL DEFAULT 'database', + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL DEFAULT '', + content_type VARCHAR(100) NOT NULL DEFAULT '', + byte_size INT NOT NULL DEFAULT 0, + sha256 VARCHAR(64) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key + ON user_avatars (user_id); + +INSERT INTO settings (key, value) +VALUES + ('auth_source_default_email_balance', '0'), + ('auth_source_default_email_concurrency', '5'), + ('auth_source_default_email_subscriptions', '[]'), + ('auth_source_default_email_grant_on_signup', 'false'), + ('auth_source_default_email_grant_on_first_bind', 'false'), + ('auth_source_default_linuxdo_balance', '0'), + ('auth_source_default_linuxdo_concurrency', '5'), + ('auth_source_default_linuxdo_subscriptions', '[]'), + ('auth_source_default_linuxdo_grant_on_signup', 'false'), + ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), + ('auth_source_default_oidc_balance', '0'), + ('auth_source_default_oidc_concurrency', '5'), + ('auth_source_default_oidc_subscriptions', '[]'), + ('auth_source_default_oidc_grant_on_signup', 'false'), + ('auth_source_default_oidc_grant_on_first_bind', 'false'), + ('auth_source_default_wechat_balance', '0'), + ('auth_source_default_wechat_concurrency', '5'), + 
('auth_source_default_wechat_subscriptions', '[]'), + ('auth_source_default_wechat_grant_on_signup', 'false'), + ('auth_source_default_wechat_grant_on_first_bind', 'false'), + ('force_email_on_third_party_signup', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql new file mode 100644 index 0000000000000000000000000000000000000000..f222a8d40a9f18b376409ccb9587715eda637985 --- /dev/null +++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql @@ -0,0 +1,8 @@ +INSERT INTO settings (key, value) +VALUES + ('payment_visible_method_alipay_source', ''), + ('payment_visible_method_wxpay_source', ''), + ('payment_visible_method_alipay_enabled', 'false'), + ('payment_visible_method_wxpay_enabled', 'false'), + ('openai_advanced_scheduler_enabled', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql new file mode 100644 index 0000000000000000000000000000000000000000..d331b824419ab66810e65a3f2c801debc42f4709 --- /dev/null +++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql @@ -0,0 +1,10 @@ +ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30); + +UPDATE payment_orders +SET provider_key = ( + SELECT provider_key + FROM payment_provider_instances + WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id +) +WHERE provider_key IS NULL + AND provider_instance_id IS NOT NULL; diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql new file mode 100644 index 0000000000000000000000000000000000000000..15610af0d2a7f9d660ccb534c068d3a87c84619a --- /dev/null +++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql @@ -0,0 +1,89 @@ +UPDATE auth_identities AS ai +SET + 
provider_key = 'wechat-main', + metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject + ); + +UPDATE auth_identity_channels AS channel +SET + provider_key = 'wechat-main', + metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identity_channels AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject + ); + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_provider_key_conflict', + CAST(ai.id AS TEXT), + jsonb_build_object( + 'legacy_identity_id', ai.id, + 'legacy_user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'canonical_identity_id', canon.id, + 'canonical_user_id', canon.user_id, + 'same_user', canon.user_id = ai.user_id, + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identities AS ai +JOIN auth_identities AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, 
details) +SELECT + 'wechat_channel_provider_key_conflict', + CAST(channel.id AS TEXT), + jsonb_build_object( + 'legacy_channel_id', channel.id, + 'legacy_identity_id', channel.identity_id, + 'canonical_channel_id', canon.id, + 'canonical_identity_id', canon.identity_id, + 'channel', channel.channel, + 'channel_app_id', channel.channel_app_id, + 'channel_subject', channel.channel_subject, + 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE), + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identity_channels AS channel +JOIN auth_identity_channels AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject +LEFT JOIN auth_identities AS legacy_identity + ON legacy_identity.id = channel.identity_id +LEFT JOIN auth_identities AS canonical_identity + ON canonical_identity.id = canon.identity_id +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql new file mode 100644 index 0000000000000000000000000000000000000000..f84bf822921fc5135c5ab5d659302b40ad417617 --- /dev/null +++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql @@ -0,0 +1,11 @@ +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT ''; + +CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at + ON auth_identity_migration_reports (resolved_at); diff --git 
a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql new file mode 100644 index 0000000000000000000000000000000000000000..264da3c9b3f3540d72f12e5b18a57730a2bc9de5 --- /dev/null +++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql @@ -0,0 +1,268 @@ +CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT) +RETURNS JSONB +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN '{}'::jsonb; + END IF; + + BEGIN + parsed := input_text::jsonb; + EXCEPTION + WHEN OTHERS THEN + RETURN '{}'::jsonb; + END; + + IF jsonb_typeof(parsed) = 'object' THEN + RETURN parsed; + END IF; + + RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed); +END; +$$; + +DO $$ +BEGIN + IF to_regclass('public.user_external_identities') IS NULL THEN + RETURN; + END IF; + + EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_user_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_user_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_user_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = 
legacy.provider_user_id + AND subjects.distinct_user_count = 1 +) +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + 'linuxdo', + 'linuxdo', + legacy.provider_user_id, + COALESCE(legacy.updated_at, legacy.created_at, NOW()), + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(uei.provider_username) AS provider_username, + BTRIM(uei.display_name) AS display_name, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + uei.created_at, + uei.updated_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_union_id + ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 +) +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, 
+ verified_at, + metadata +) +SELECT + legacy.user_id, + 'wechat', + 'wechat-main', + legacy.provider_union_id, + COALESCE(legacy.updated_at, legacy.created_at, NOW()), + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'openid', legacy.provider_user_id, + 'unionid', legacy.provider_union_id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', legacy.provider_union_id, + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM canonical_legacy AS legacy +WHERE legacy.canonical_row_num = 1 +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + BTRIM(uei.provider_union_id) AS provider_union_id, + BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel, + BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id, + meta.metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + CROSS JOIN LATERAL ( + SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + ) AS meta + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' +), +legacy_subjects AS ( + SELECT + provider_union_id AS provider_subject, + COUNT(DISTINCT user_id) AS distinct_user_count + FROM legacy + GROUP BY provider_union_id +) +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +SELECT + ai.id, + 'wechat', + 'wechat-main', + legacy.channel, + legacy.channel_app_id, + legacy.provider_user_id, + legacy.metadata_json || jsonb_build_object( + 'openid', legacy.provider_user_id, + 
'unionid', legacy.provider_union_id, + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM legacy +JOIN legacy_subjects AS subjects + ON subjects.provider_subject = legacy.provider_union_id + AND subjects.distinct_user_count = 1 +JOIN auth_identities AS ai + ON ai.user_id = legacy.user_id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat-main' + AND ai.provider_subject = legacy.provider_union_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND legacy.provider_user_id <> '' +ON CONFLICT DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'user_id', legacy.user_id, + 'openid', legacy.provider_user_id, + 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline', + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(uei.provider_user_id) AS provider_user_id, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_union_id, '')) = '' +) AS legacy +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; +END $$; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'synthetic_auth_identity:' || ai.id::text, + COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'auth_identity_id', ai.id, + 'user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'reason', 'synthetic wechat auth identity still 
lacks unionid metadata and needs remediation', + 'migration', '115_auth_identity_legacy_external_backfill' + ) +FROM auth_identities AS ai +WHERE ai.provider_type = 'wechat' + AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email' + AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = '' +ON CONFLICT (report_type, report_key) DO NOTHING; + +DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT); diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql new file mode 100644 index 0000000000000000000000000000000000000000..81eb133cefd66a602e6d8db7191d5d8175625564 --- /dev/null +++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql @@ -0,0 +1,525 @@ +CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT) +RETURNS JSONB +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN '{}'::jsonb; + END IF; + + BEGIN + parsed := input_text::jsonb; + EXCEPTION + WHEN OTHERS THEN + RETURN '{}'::jsonb; + END; + + IF jsonb_typeof(parsed) = 'object' THEN + RETURN parsed; + END IF; + + RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed); +END; +$$; + +CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT) +RETURNS BOOLEAN +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN TRUE; + END IF; + + parsed := input_text::jsonb; + RETURN TRUE; +EXCEPTION + WHEN OTHERS THEN + RETURN FALSE; +END; +$$; + +DO $$ +BEGIN + IF to_regclass('public.user_external_identities') IS NULL THEN + RETURN; + END IF; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_invalid_metadata_json', + 'legacy_external_identity:' || uei.id::text, + 
jsonb_build_object( + 'legacy_identity_id', uei.id, + 'user_id', uei.user_id, + 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))), + 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')), + 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')), + 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object', + 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000), + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM user_external_identities AS uei +JOIN users AS u ON u.id = uei.user_id +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(uei.metadata, '')) <> '' + AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata) +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'provider_type', legacy.provider_type, + 'provider_key', legacy.provider_key, + 'provider_subject', legacy.provider_subject, + 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids, + 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM 
user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject, + to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) > 1 +) AS ambiguous + ON ambiguous.provider_type = legacy.provider_type + AND ambiguous.provider_key = legacy.provider_key + AND ambiguous.provider_subject = legacy.provider_subject +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_external_identity_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 
'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'existing_identity_id', ai.id, + 'existing_user_id', ai.user_id, + 'provider_type', legacy.provider_type, + 'provider_key', legacy.provider_key, + 'provider_subject', legacy.provider_subject, + 'reason', 'legacy canonical identity subject already belongs to another user', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + BTRIM(COALESCE(uei.provider_username, '')) AS provider_username, + BTRIM(COALESCE(uei.display_name, '')) AS display_name, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +) AS legacy +JOIN ( + SELECT + provider_type, + provider_key, + provider_subject + FROM ( + SELECT + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, 
'')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) + ) AS legacy_subjects + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject +JOIN auth_identities AS ai + ON ai.provider_type = legacy.provider_type + AND ai.provider_key = legacy.provider_key + AND ai.provider_subject = legacy.provider_subject +WHERE ai.user_id <> legacy.user_id +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.id, + uei.user_id, + LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main' + ELSE 'linuxdo' + END AS provider_key, + CASE + WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, '')) + ELSE BTRIM(COALESCE(uei.provider_user_id, '')) + END AS provider_subject, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + BTRIM(COALESCE(uei.provider_username, '')) AS provider_username, + BTRIM(COALESCE(uei.display_name, '')) AS display_name, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at + FROM user_external_identities AS uei + JOIN users AS u ON u.id = 
uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat') + AND ( + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '') + OR + (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') + ) +), +clear_subjects AS ( + SELECT + provider_type, + provider_key, + provider_subject + FROM legacy + GROUP BY provider_type, provider_key, provider_subject + HAVING COUNT(DISTINCT user_id) = 1 +), +canonical_legacy AS ( + SELECT + legacy.*, + ROW_NUMBER() OVER ( + PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject + ORDER BY legacy.verified_at DESC, legacy.id DESC + ) AS canonical_row_num + FROM legacy + JOIN clear_subjects + ON clear_subjects.provider_type = legacy.provider_type + AND clear_subjects.provider_key = legacy.provider_key + AND clear_subjects.provider_subject = legacy.provider_subject +) +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + legacy.user_id, + legacy.provider_type, + legacy.provider_key, + legacy.provider_subject, + legacy.verified_at, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'provider_user_id', legacy.provider_user_id, + 'provider_union_id', NULLIF(legacy.provider_union_id, ''), + 'provider_username', legacy.provider_username, + 'display_name', legacy.display_name, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM canonical_legacy AS legacy +LEFT JOIN auth_identities AS ai + ON ai.provider_type = legacy.provider_type + AND ai.provider_key = legacy.provider_key + AND ai.provider_subject = legacy.provider_subject +WHERE legacy.canonical_row_num = 1 + AND ai.id IS NULL +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, 
report_key, details) +SELECT + 'legacy_external_channel_conflict', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'legacy_user_id', legacy.user_id, + 'existing_channel_id', channel.id, + 'existing_identity_id', existing_ai.id, + 'existing_user_id', existing_ai.user_id, + 'provider_type', 'wechat', + 'provider_key', 'wechat-main', + 'provider_subject', legacy.provider_union_id, + 'channel', legacy.channel, + 'channel_app_id', legacy.channel_app_id, + 'channel_subject', legacy.provider_user_id, + 'reason', 'legacy channel subject already belongs to another user', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +) AS legacy +JOIN ( + SELECT + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND 
BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + GROUP BY BTRIM(COALESCE(uei.provider_union_id, '')) + HAVING COUNT(DISTINCT uei.user_id) = 1 +) AS clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id +JOIN auth_identities AS legacy_ai + ON legacy_ai.user_id = legacy.user_id + AND legacy_ai.provider_type = 'wechat' + AND legacy_ai.provider_key = 'wechat-main' + AND legacy_ai.provider_subject = legacy.provider_union_id +JOIN auth_identity_channels AS channel + ON channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = legacy.channel + AND channel.channel_app_id = legacy.channel_app_id + AND channel.channel_subject = legacy.provider_user_id +JOIN auth_identities AS existing_ai + ON existing_ai.id = channel.identity_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND existing_ai.user_id <> legacy.user_id +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; + + EXECUTE $sql$ +WITH legacy AS ( + SELECT + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, + BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel, + BTRIM(COALESCE( + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid', + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id', + '' + )) AS channel_app_id + FROM user_external_identities AS uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' +), +clear_subjects AS ( + SELECT + provider_union_id AS provider_subject + FROM 
legacy + GROUP BY provider_union_id + HAVING COUNT(DISTINCT user_id) = 1 +) +INSERT INTO auth_identity_channels ( + identity_id, + provider_type, + provider_key, + channel, + channel_app_id, + channel_subject, + metadata +) +SELECT + legacy_ai.id, + 'wechat', + 'wechat-main', + legacy.channel, + legacy.channel_app_id, + legacy.provider_user_id, + legacy.metadata_json || jsonb_build_object( + 'openid', legacy.provider_user_id, + 'unionid', legacy.provider_union_id, + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM legacy +JOIN clear_subjects + ON clear_subjects.provider_subject = legacy.provider_union_id +JOIN auth_identities AS legacy_ai + ON legacy_ai.user_id = legacy.user_id + AND legacy_ai.provider_type = 'wechat' + AND legacy_ai.provider_key = 'wechat-main' + AND legacy_ai.provider_subject = legacy.provider_union_id +LEFT JOIN auth_identity_channels AS channel + ON channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat-main' + AND channel.channel = legacy.channel + AND channel.channel_app_id = legacy.channel_app_id + AND channel.channel_subject = legacy.provider_user_id +WHERE legacy.channel <> '' + AND legacy.channel_app_id <> '' + AND channel.id IS NULL +ON CONFLICT DO NOTHING; +$sql$; + + EXECUTE $sql$ +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + 'legacy_external_identity:' || legacy.id::text, + legacy.metadata_json || jsonb_build_object( + 'legacy_identity_id', legacy.id, + 'user_id', legacy.user_id, + 'openid', legacy.provider_user_id, + 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline', + 'migration', '116_auth_identity_legacy_external_safety_reports' + ) +FROM ( + SELECT + uei.id, + uei.user_id, + BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id, + public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json + FROM user_external_identities AS 
uei + JOIN users AS u ON u.id = uei.user_id + WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' + AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' + AND BTRIM(COALESCE(uei.provider_union_id, '')) = '' +) AS legacy +ON CONFLICT (report_type, report_key) DO NOTHING; +$sql$; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identities_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identities + ADD CONSTRAINT auth_identities_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_channels_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identity_channels + ADD CONSTRAINT auth_identity_channels_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_migration_reports_details_is_object_check' + ) THEN + ALTER TABLE auth_identity_migration_reports + ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check + CHECK (jsonb_typeof(details) = 'object'); + END IF; +END $$; + +DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT); +DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT); diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql new file mode 100644 index 0000000000000000000000000000000000000000..56a5fe2dc504bf9fa34d04588206f9a43b66af43 --- /dev/null +++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql @@ -0,0 +1,2 @@ +ALTER TABLE payment_orders +ADD COLUMN IF NOT EXISTS provider_snapshot JSONB; diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql new file mode 100644 index 
0000000000000000000000000000000000000000..187826179b411ea2489750f3c75203488c09e854 --- /dev/null +++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql @@ -0,0 +1,25 @@ +INSERT INTO settings (key, value) +VALUES + ( + 'wechat_connect_open_enabled', + CASE + WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN '' + WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' + WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false' + ELSE 'true' + END + ), + ( + 'wechat_connect_mp_enabled', + CASE + WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN '' + WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false' + WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true' + ELSE 'false' + END + ), + ('auth_source_default_email_grant_on_signup', 'false'), + ('auth_source_default_linuxdo_grant_on_signup', 'false'), + ('auth_source_default_oidc_grant_on_signup', 'false'), + ('auth_source_default_wechat_grant_on_signup', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql new file mode 100644 index 0000000000000000000000000000000000000000..15e2c15f4cef9ad14c43030a945dba766e0d7242 --- /dev/null +++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql @@ -0,0 +1,6 @@ +-- Intentionally left as a no-op. 
+-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql +DO $$ +BEGIN + NULL; +END $$; diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql new file mode 100644 index 0000000000000000000000000000000000000000..638d8622a368b71f599b9eb6d81b7288f2080dd1 --- /dev/null +++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql @@ -0,0 +1,10 @@ +-- Build the payment order uniqueness guarantee online. +-- The migration runner performs an explicit duplicate out_trade_no precheck and +-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying. +-- Create the new partial unique index concurrently first so writes keep flowing, +-- then remove the legacy index name once the replacement is ready. +CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique + ON payment_orders (out_trade_no) + WHERE out_trade_no <> ''; + +DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no; diff --git a/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql new file mode 100644 index 0000000000000000000000000000000000000000..ef2599dc2495b91b66e7795e77825edd05d8cc66 --- /dev/null +++ b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql @@ -0,0 +1,22 @@ +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'payment_orders' + AND indexname = 'paymentorder_out_trade_no_unique' + ) THEN + IF EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'payment_orders' + AND indexname = 'paymentorder_out_trade_no' + ) THEN + EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no'; + END IF; + + EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no'; + END IF; 
+END $$; diff --git a/backend/migrations/121_auth_identity_migration_report_type_widen.sql b/backend/migrations/121_auth_identity_migration_report_type_widen.sql new file mode 100644 index 0000000000000000000000000000000000000000..66bfb44a7b585651d9288927c6cc68be8676c5d9 --- /dev/null +++ b/backend/migrations/121_auth_identity_migration_report_type_widen.sql @@ -0,0 +1,2 @@ +ALTER TABLE auth_identity_migration_reports +ALTER COLUMN report_type TYPE VARCHAR(80); diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql new file mode 100644 index 0000000000000000000000000000000000000000..e634114298af4aaeb99cb280f78cb00b24f6e63f --- /dev/null +++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql @@ -0,0 +1,15 @@ +UPDATE pending_auth_sessions +SET + local_flow_state = jsonb_set( + local_flow_state, + '{completion_response}', + ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'), + true + ) +WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object' + AND ( + (local_flow_state -> 'completion_response') ? 'access_token' + OR (local_flow_state -> 'completion_response') ? 'refresh_token' + OR (local_flow_state -> 'completion_response') ? 'expires_in' + OR (local_flow_state -> 'completion_response') ? 'token_type' + ); diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql new file mode 100644 index 0000000000000000000000000000000000000000..4388285ac49d31bd4b54606ec4f0ebbf80ecb10b --- /dev/null +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -0,0 +1,68 @@ +-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value. 
+-- Rows still matching the migration-110 default payload and timestamp window are treated as +-- untouched legacy defaults; any remaining legacy true values are reported for manual review. + +WITH migration_110 AS ( + SELECT applied_at + FROM schema_migrations + WHERE filename = '110_pending_auth_and_provider_default_grants.sql' +), +providers AS ( + SELECT provider_type + FROM ( + VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') + ) AS providers(provider_type) +), +legacy_provider_defaults AS ( + SELECT providers.provider_type + FROM providers + CROSS JOIN migration_110 + JOIN settings balance + ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' + JOIN settings concurrency + ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' + JOIN settings subscriptions + ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' + JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' + JOIN settings grant_on_first_bind + ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' + WHERE balance.value = '0' + AND concurrency.value = '5' + AND subscriptions.value = '[]' + AND grant_on_signup.value = 'true' + AND grant_on_first_bind.value = 'false' + AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL 
'1 minute' AND migration_110.applied_at + INTERVAL '1 minute' +), +updated_signup_grants AS ( + UPDATE settings + SET + value = 'false', + updated_at = NOW() + FROM legacy_provider_defaults + WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' + AND settings.value = 'true' + RETURNING legacy_provider_defaults.provider_type +) +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'legacy_auth_source_signup_grant_review', + providers.provider_type, + jsonb_build_object( + 'provider_type', providers.provider_type, + 'current_value', grant_on_signup.value, + 'auto_backfilled', FALSE, + 'reason', 'legacy_true_default_not_auto_backfilled' + ) +FROM providers +JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' +LEFT JOIN updated_signup_grants + ON updated_signup_grants.provider_type = providers.provider_type +WHERE grant_on_signup.value = 'true' + AND updated_signup_grants.provider_type IS NULL +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/124_backfill_legacy_oidc_security_flags.sql b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql new file mode 100644 index 0000000000000000000000000000000000000000..e68bb11a499ff50bef18c5c2e28bec176a57c08a --- /dev/null +++ b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql @@ -0,0 +1,32 @@ +-- Preserve legacy OIDC behavior for upgraded installs that predate the +-- introduction of secure PKCE/id_token defaults. Fresh installs continue to +-- inherit runtime defaults when these rows are absent. 
+ +WITH legacy_oidc_install AS ( + SELECT 1 + FROM settings + WHERE key IN ( + 'oidc_connect_enabled', + 'oidc_connect_client_id', + 'oidc_connect_authorize_url', + 'oidc_connect_token_url', + 'oidc_connect_issuer_url', + 'oidc_connect_userinfo_url', + 'oidc_connect_frontend_redirect_url' + ) + LIMIT 1 +) +INSERT INTO settings (key, value) +SELECT defaults.key, 'false' +FROM legacy_oidc_install +CROSS JOIN ( + VALUES + ('oidc_connect_use_pkce'), + ('oidc_connect_validate_id_token') +) AS defaults(key) +WHERE NOT EXISTS ( + SELECT 1 + FROM settings existing + WHERE existing.key = defaults.key +) +ON CONFLICT (key) DO NOTHING; diff --git a/backend/migrations/125_add_channel_monitors.sql b/backend/migrations/125_add_channel_monitors.sql new file mode 100644 index 0000000000000000000000000000000000000000..5ec327da819aff6ebb33a54128e4434a77e9bee8 --- /dev/null +++ b/backend/migrations/125_add_channel_monitors.sql @@ -0,0 +1,58 @@ +-- Migration: 125_add_channel_monitors +-- 渠道监控 MVP:周期性对外部 provider/endpoint/api_key 做模型心跳测试。 +-- +-- 表结构说明: +-- - channel_monitors 渠道配置表(一行 = 一个监控对象) +-- - channel_monitor_histories 检测历史明细表(一次检测一个模型 = 一行) +-- +-- 设计要点: +-- - api_key_encrypted 列存放 AES-256-GCM 密文(base64),由 service 层加密。 +-- - extra_models 用 JSONB 存储字符串数组,便于扩展(后续可加权重等元数据)。 +-- - history 表通过 ON DELETE CASCADE 自动清理已删除监控的历史。 +-- - (enabled, last_checked_at) 索引服务于调度器扫描“到期需要检测”的监控。 +-- - histories 上 (monitor_id, model, checked_at DESC) 服务用户视图聚合查询; +-- 单独的 (checked_at) 索引服务定期清理 30 天前数据的 DELETE。 + +CREATE TABLE IF NOT EXISTS channel_monitors ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + provider VARCHAR(20) NOT NULL, -- openai / anthropic / gemini + endpoint VARCHAR(500) NOT NULL, -- base origin + api_key_encrypted TEXT NOT NULL, -- AES-256-GCM (base64) + primary_model VARCHAR(200) NOT NULL, + extra_models JSONB NOT NULL DEFAULT '[]'::jsonb, + group_name VARCHAR(100) NOT NULL DEFAULT '', + enabled BOOLEAN NOT NULL DEFAULT TRUE, + interval_seconds INT NOT NULL, + 
last_checked_at TIMESTAMPTZ, + created_by BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT channel_monitors_provider_check CHECK (provider IN ('openai', 'anthropic', 'gemini')), + CONSTRAINT channel_monitors_interval_check CHECK (interval_seconds BETWEEN 15 AND 3600) +); + +CREATE INDEX IF NOT EXISTS idx_channel_monitors_enabled_last_checked + ON channel_monitors (enabled, last_checked_at); +CREATE INDEX IF NOT EXISTS idx_channel_monitors_provider + ON channel_monitors (provider); +CREATE INDEX IF NOT EXISTS idx_channel_monitors_group_name + ON channel_monitors (group_name); + +CREATE TABLE IF NOT EXISTS channel_monitor_histories ( + id BIGSERIAL PRIMARY KEY, + monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE, + model VARCHAR(200) NOT NULL, + status VARCHAR(20) NOT NULL, + latency_ms INT, + ping_latency_ms INT, + message VARCHAR(500) NOT NULL DEFAULT '', + checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT channel_monitor_histories_status_check + CHECK (status IN ('operational', 'degraded', 'failed', 'error')) +); + +CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_monitor_model_checked + ON channel_monitor_histories (monitor_id, model, checked_at DESC); +CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_checked_at + ON channel_monitor_histories (checked_at); diff --git a/backend/migrations/125_add_group_rpm_limit.sql b/backend/migrations/125_add_group_rpm_limit.sql new file mode 100644 index 0000000000000000000000000000000000000000..fbde1b20bbc422ec2e816565795edbd87170b50b --- /dev/null +++ b/backend/migrations/125_add_group_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add per-group Requests-Per-Minute limit. 
+-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。 +-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。 +-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。 +ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0; + +COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。'; diff --git a/backend/migrations/126_add_channel_monitor_aggregation.sql b/backend/migrations/126_add_channel_monitor_aggregation.sql new file mode 100644 index 0000000000000000000000000000000000000000..e643763cd3e737ed6c710c945e981593d774480b --- /dev/null +++ b/backend/migrations/126_add_channel_monitor_aggregation.sql @@ -0,0 +1,60 @@ +-- Migration: 126_add_channel_monitor_aggregation +-- 渠道监控日聚合:把 channel_monitor_histories 的明细按天聚合,明细只保留 1 天, +-- 聚合保留 30 天。明细和聚合表都用软删除(deleted_at),由 ops cleanup 任务每天 +-- 凌晨随运维监控清理一起跑(共享 cron)。 +-- +-- 设计要点: +-- - channel_monitor_histories 加 deleted_at 软删除字段(SoftDeleteMixin 全局 +-- Hook 会把 DELETE 自动改写成 UPDATE deleted_at = NOW())。 +-- - channel_monitor_daily_rollups 按 (monitor_id, model, bucket_date) 唯一, +-- 用 ON CONFLICT DO UPDATE 实现幂等回填,状态分布和延迟分子分母都保留, +-- 方便后续按窗口任意求加权可用率和均值。 +-- - watermark 表只有一行(id=1),记录最近一次聚合到达的日期,避免重启后重复 +-- 扫全表。 +-- - rollup 上 (bucket_date) 索引服务清理任务的 DELETE WHERE bucket_date < cutoff。 + +-- 1) 给历史明细表加软删除字段 +ALTER TABLE channel_monitor_histories + ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ; + +CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_deleted_at + ON channel_monitor_histories (deleted_at); + +-- 2) 创建日聚合表 +CREATE TABLE IF NOT EXISTS channel_monitor_daily_rollups ( + id BIGSERIAL PRIMARY KEY, + monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE, + model VARCHAR(200) NOT NULL, + bucket_date DATE NOT NULL, + total_checks INT NOT NULL DEFAULT 0, + ok_count INT NOT NULL DEFAULT 0, + operational_count INT NOT NULL DEFAULT 0, + degraded_count INT NOT NULL DEFAULT 0, + failed_count INT NOT NULL DEFAULT 0, + error_count INT NOT NULL DEFAULT 0, + sum_latency_ms BIGINT NOT NULL 
DEFAULT 0, + count_latency INT NOT NULL DEFAULT 0, + sum_ping_latency_ms BIGINT NOT NULL DEFAULT 0, + count_ping_latency INT NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_unique + ON channel_monitor_daily_rollups (monitor_id, model, bucket_date); +CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_bucket + ON channel_monitor_daily_rollups (bucket_date); +CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_deleted_at + ON channel_monitor_daily_rollups (deleted_at); + +-- 3) 创建 watermark 表(单行:id=1) +CREATE TABLE IF NOT EXISTS channel_monitor_aggregation_watermark ( + id INT PRIMARY KEY DEFAULT 1, + last_aggregated_date DATE, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT channel_monitor_aggregation_watermark_singleton CHECK (id = 1) +); + +INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at) +VALUES (1, NULL, NOW()) +ON CONFLICT (id) DO NOTHING; diff --git a/backend/migrations/126_add_user_rpm_limit.sql b/backend/migrations/126_add_user_rpm_limit.sql new file mode 100644 index 0000000000000000000000000000000000000000..64a8b97704b22b7631b7fd4c9b73f80ccfe791b3 --- /dev/null +++ b/backend/migrations/126_add_user_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add per-user Requests-Per-Minute cap. 
+-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。
+-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。
+-- 计数键:rpm:u:{user_id}:{minute}。
+ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。';
diff --git a/backend/migrations/127_add_user_group_rpm_override.sql b/backend/migrations/127_add_user_group_rpm_override.sql
new file mode 100644
index 0000000000000000000000000000000000000000..1d67425834611026f989647a9709564b26b910bf
--- /dev/null
+++ b/backend/migrations/127_add_user_group_rpm_override.sql
@@ -0,0 +1,16 @@
+-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空,
+-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。
+-- 语义:
+-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值
+-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率
+-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值
+-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制)
+-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。
+ALTER TABLE user_group_rate_multipliers
+    ADD COLUMN IF NOT EXISTS rpm_override integer NULL;
+
+ALTER TABLE user_group_rate_multipliers
+    ALTER COLUMN rate_multiplier DROP NOT NULL;
+
+COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。';
+COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。';
diff --git a/backend/migrations/127_drop_channel_monitor_deleted_at.sql b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
new file mode 100644
index 0000000000000000000000000000000000000000..2260f06b419db3e184180430c703ef215ccc2ade
--- /dev/null
+++ b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
@@ -0,0 +1,16 @@
+-- Migration: 127_drop_channel_monitor_deleted_at
+-- 纠正 126 引入的 SoftDeleteMixin:日志/聚合表无恢复需求,软删会让行和索引只增不减,
+-- 徒增磁盘和查询开销。改回分批物理删(由 OpsCleanupService 每天凌晨统一调度,
+-- deleteOldRowsByID 模板,batch=5000)。
+--
+-- 126 尚未跑过聚合/清理(首次 maintenance 在次日 02:00),所以此处不担心业务数据。
+-- 直接 DROP 列 + 
索引;对应的 Go 侧 ent schema 已移除 SoftDeleteMixin、repo 的 +-- raw SQL 已移除 deleted_at IS NULL 过滤。 + +DROP INDEX IF EXISTS idx_channel_monitor_histories_deleted_at; +ALTER TABLE channel_monitor_histories + DROP COLUMN IF EXISTS deleted_at; + +DROP INDEX IF EXISTS idx_channel_monitor_daily_rollups_deleted_at; +ALTER TABLE channel_monitor_daily_rollups + DROP COLUMN IF EXISTS deleted_at; diff --git a/backend/migrations/128_add_channel_monitor_request_templates.sql b/backend/migrations/128_add_channel_monitor_request_templates.sql new file mode 100644 index 0000000000000000000000000000000000000000..2db8fef69b062958d8d4a90d7efde704acb579f9 --- /dev/null +++ b/backend/migrations/128_add_channel_monitor_request_templates.sql @@ -0,0 +1,70 @@ +-- Migration: 128_add_channel_monitor_request_templates +-- 加请求模板表 + 给 channel_monitors 加 4 个快照字段(template_id 关联引用 + extra_headers / +-- body_override_mode / body_override 三个真正运行时使用的快照)。 +-- +-- 设计要点: +-- 1) 模板与监控之间是「应用即拷贝」的快照语义,运行时 checker 不再回查模板表。 +-- 模板 UPDATE 不会自动影响监控;只有用户主动「应用到关联监控」才会刷新快照。 +-- 2) ON DELETE SET NULL:模板删除不级联清理监控;监控保留快照继续工作。 +-- 3) extra_headers / body_override 都是 JSONB;body_override_mode 用 varchar(不是 enum) +-- 便于将来加新模式无需 ALTER TYPE。 +-- 4) 同一 provider 内模板 name 唯一(允许 Anthropic + OpenAI 重名 "伪装官方客户端")。 + +CREATE TABLE IF NOT EXISTS channel_monitor_request_templates ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + provider VARCHAR(20) NOT NULL, + description VARCHAR(500) NOT NULL DEFAULT '', + extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb, + body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off', + body_override JSONB NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT channel_monitor_request_templates_provider_check + CHECK (provider IN ('openai', 'anthropic', 'gemini')), + CONSTRAINT channel_monitor_request_templates_body_mode_check + CHECK (body_override_mode IN ('off', 'merge', 'replace')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS 
channel_monitor_request_templates_provider_name + ON channel_monitor_request_templates (provider, name); + +-- channel_monitors 加 4 列(ADD COLUMN IF NOT EXISTS 需要 PG 9.6+,生产使用 PG 16) +ALTER TABLE channel_monitors + ADD COLUMN IF NOT EXISTS template_id BIGINT NULL; +ALTER TABLE channel_monitors + ADD COLUMN IF NOT EXISTS extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb; +ALTER TABLE channel_monitors + ADD COLUMN IF NOT EXISTS body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off'; +ALTER TABLE channel_monitors + ADD COLUMN IF NOT EXISTS body_override JSONB NULL; + +-- 约束 + 外键(DO 块里 IF NOT EXISTS 判断,保证幂等) +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.table_constraints + WHERE constraint_name = 'channel_monitors_body_mode_check' + AND table_name = 'channel_monitors' + ) THEN + ALTER TABLE channel_monitors + ADD CONSTRAINT channel_monitors_body_mode_check + CHECK (body_override_mode IN ('off', 'merge', 'replace')); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM information_schema.table_constraints + WHERE constraint_name = 'channel_monitors_template_id_fkey' + AND table_name = 'channel_monitors' + ) THEN + ALTER TABLE channel_monitors + ADD CONSTRAINT channel_monitors_template_id_fkey + FOREIGN KEY (template_id) + REFERENCES channel_monitor_request_templates (id) + ON DELETE SET NULL; + END IF; +END $$; + +CREATE INDEX IF NOT EXISTS idx_channel_monitors_template_id + ON channel_monitors (template_id) + WHERE template_id IS NOT NULL; diff --git a/backend/migrations/129_seed_claude_code_template.sql b/backend/migrations/129_seed_claude_code_template.sql new file mode 100644 index 0000000000000000000000000000000000000000..d9b062c907edd736746e83506d43fc7add5dc6be --- /dev/null +++ b/backend/migrations/129_seed_claude_code_template.sql @@ -0,0 +1,38 @@ +-- Migration: 129_seed_claude_code_template +-- 内置「Claude Code 伪装」请求模板,覆盖 Anthropic 上游对官方 CLI 客户端的所有验证项: +-- 1) User-Agent / X-App / anthropic-beta / anthropic-version 等头 +-- 2) system 数组首项与官方 system prompt 
字面一致(Dice >= 0.5) +-- 3) metadata.user_id 满足 ParseMetadataUserID — 这里用 legacy 格式(user_<64hex>_account__session_<36char>) +-- 避免新版 JSON 字符串内嵌 JSON 在编辑器里出现一长串 \" 转义,便于用户阅读。 +-- +-- ON CONFLICT DO NOTHING:已部署环境(手动建过模板)跑此 migration 不会重复 / 覆盖。 +-- 用户可自行编辑后续覆盖此 seed;CC 升大版时再起一条 migration 提供新模板,不动用户的旧模板。 + +INSERT INTO channel_monitor_request_templates ( + name, provider, description, extra_headers, body_override_mode, body_override +) +VALUES ( + 'Claude Code 伪装', + 'anthropic', + '完整模拟 Claude Code 2.1.114 客户端:UA + anthropic-beta + system + metadata.user_id 全部对齐,绕过 Anthropic 上游 ''Claude Code only'' 限制(如 Max 套餐)。', + '{ + "User-Agent": "claude-cli/2.1.114 (external, sdk-cli)", + "X-App": "cli", + "anthropic-version": "2023-06-01", + "anthropic-beta": "claude-code-20250219,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,advisor-tool-2026-03-01", + "anthropic-dangerous-direct-browser-access": "true" + }'::jsonb, + 'merge', + '{ + "system": [ + { + "type": "text", + "text": "You are Claude Code, Anthropic''s official CLI for Claude." 
+ } + ], + "metadata": { + "user_id": "user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000" + } + }'::jsonb +) +ON CONFLICT (provider, name) DO NOTHING; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go new file mode 100644 index 0000000000000000000000000000000000000000..798ae0fe078f17d54adf45645332fd42f5034b15 --- /dev/null +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -0,0 +1,129 @@ +package migrations + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMigration112UsesIdempotentAddColumn(t *testing.T) { + content, err := FS.ReadFile("112_add_payment_order_provider_key_snapshot.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30)") + require.NotContains(t, sql, "ADD COLUMN provider_key VARCHAR(30);") +} + +func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) { + content, err := FS.ReadFile("118_wechat_dual_mode_and_auth_source_defaults.sql") + require.NoError(t, err) + + sql := string(content) + require.NotContains(t, sql, "UPDATE settings") + require.NotContains(t, sql, "SET value = 'false'") + require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) + require.Contains(t, sql, "THEN ''") +} + +func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) { + preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql") + require.NoError(t, err) + + preflightSQL := string(preflightContent) + require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports") + require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)") + + content, err := 
FS.ReadFile("109_auth_identity_compat_backfill.sql") + require.NoError(t, err) + + sql := string(content) + require.NotContains(t, sql, "ALTER TABLE auth_identity_migration_reports") + + followupContent, err := FS.ReadFile("121_auth_identity_migration_report_type_widen.sql") + require.NoError(t, err) + + followupSQL := string(followupContent) + require.Contains(t, followupSQL, "ALTER TABLE auth_identity_migration_reports") + require.Contains(t, followupSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)") +} + +func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { + content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "120_enforce_payment_orders_out_trade_no_unique_notx.sql") + require.Contains(t, sql, "NULL;") + require.NotContains(t, sql, "CREATE UNIQUE INDEX") + require.NotContains(t, sql, "DROP INDEX") + + followupContent, err := FS.ReadFile("120_enforce_payment_orders_out_trade_no_unique_notx.sql") + require.NoError(t, err) + + followupSQL := string(followupContent) + require.Contains(t, followupSQL, "explicit duplicate out_trade_no precheck") + require.Contains(t, followupSQL, "stale invalid paymentorder_out_trade_no_unique index") + require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") + require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique") + require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") + require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") + + alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql") + require.NoError(t, err) + + alignmentSQL := string(alignmentContent) + require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique") + require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no") +} + +func 
TestMigration110SeedsAuthSourceSignupGrantsDisabledByDefault(t *testing.T) { + content, err := FS.ReadFile("110_pending_auth_and_provider_default_grants.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "('auth_source_default_email_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_linuxdo_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_oidc_grant_on_signup', 'false')") + require.Contains(t, sql, "('auth_source_default_wechat_grant_on_signup', 'false')") + require.NotContains(t, sql, "('auth_source_default_email_grant_on_signup', 'true')") +} + +func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { + content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "UPDATE pending_auth_sessions") + require.Contains(t, sql, "completion_response") + require.Contains(t, sql, "access_token") + require.Contains(t, sql, "refresh_token") + require.Contains(t, sql, "expires_in") + require.Contains(t, sql, "token_type") +} + +func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) { + content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql") + require.Contains(t, sql, "schema_migrations") + require.Contains(t, sql, "updated_at") + require.Contains(t, sql, "'_grant_on_signup'") + require.Contains(t, sql, "value = 'false'") + require.Contains(t, sql, "auth_identity_migration_reports") +} + +func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) { + content, err := FS.ReadFile("124_backfill_legacy_oidc_security_flags.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "oidc_connect_use_pkce") + require.Contains(t, sql, 
"oidc_connect_validate_id_token") + require.Contains(t, sql, "ON CONFLICT (key) DO NOTHING") + require.Contains(t, sql, "oidc_connect_enabled") + require.Contains(t, sql, "'false'") +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 358f6a31d9db662dbe12fa74d9a832ccbe5a16ee..dfc363b5a5e706f8a9637585694665682e2b1cd8 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -841,7 +841,7 @@ linuxdo_connect: frontend_redirect_url: "/auth/linuxdo/callback" token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE - use_pkce: false + use_pkce: true userinfo_email_path: "" userinfo_id_path: "" userinfo_username_path: "" diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md index b66a791c5f8392174cf2ea1b821c6603bfec10a5..af93fa7ee17c1ce03110d9b820c8dca5dc246285 100644 --- a/docs/PAYMENT.md +++ b/docs/PAYMENT.md @@ -22,13 +22,18 @@ Sub2API has a built-in payment system that enables user self-service top-up with | Provider | Payment Methods | Description | |----------|----------------|-------------| | **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol | -| **Alipay (Direct)** | PC Page Pay, H5 Mobile Pay | Direct integration with Alipay Open Platform, auto-switches by device | -| **WeChat Pay (Direct)** | Native QR Code, H5 Pay | Direct integration with WeChat Pay APIv3, mobile-first H5 | +| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links | +| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing | | **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support | -> Alipay/WeChat Pay direct and EasyPay can coexist. 
Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup. +> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup. -> **EasyPay Recommendation**: [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`) is recommended as an EasyPay provider (link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it). ZPay supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. +> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need: +> +> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it. 
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer. +> +> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. 
--- @@ -56,9 +61,18 @@ Configure the following in Admin Dashboard **Settings → Payment Settings**: | **Minimum Amount** | Minimum single top-up amount | 1 | | **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - | | **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - | -| **Order Timeout** | Order timeout in minutes (minimum 1) | 5 | +| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 | | **Max Pending Orders** | Maximum concurrent pending orders per user | 3 | -| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount | +| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin | + +### Frontend Visible Method Routing + +The current payment UX keeps the frontend method list unified and does not expose provider brands directly: + +- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay` +- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat` +- Each visible method can route to only one source at a time +- If a visible method is enabled without a selected source, the frontend will not expose that method ### Load Balance Strategies @@ -108,7 +122,7 @@ Compatible with any payment service that implements the EasyPay protocol. ### Alipay (Direct) -Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile pay. +Direct integration with Alipay Open Platform. Mobile flows return an Alipay WAP/app redirect URL. Desktop flows prefer Face-to-Face Precreate QR payloads; if the merchant has not enabled that product, the provider falls back to Computer Website Pay and also returns the cashier URL so the frontend can render a QR code or open the hosted checkout page directly. | Parameter | Description | Required | |-----------|-------------|----------| @@ -118,7 +132,7 @@ Direct integration with Alipay Open Platform. 
Supports PC page pay and H5 mobile ### WeChat Pay (Direct) -Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment. +Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment. | Parameter | Description | Required | |-----------|-------------|----------| @@ -127,8 +141,8 @@ Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment | **Merchant API Private Key** | Merchant API private key (PEM format) | Yes | | **APIv3 Key** | 32-byte APIv3 key | Yes | | **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes | -| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No | -| **Certificate Serial Number** | Merchant certificate serial number | No | +| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes | +| **Certificate Serial Number** | Merchant certificate serial number | Yes | ### Stripe @@ -215,8 +229,8 @@ User selects amount and payment method ▼ User completes payment ├─ EasyPay → QR code / H5 redirect - ├─ Alipay → PC page pay / H5 mobile pay - ├─ WeChat Pay → Native QR / H5 pay + ├─ Alipay → Desktop QR payload (Face-to-Face preferred, Website Pay fallback) / mobile Alipay redirect + ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI └─ Stripe → Payment Element (card/Alipay/WeChat/etc.) 
│ ▼ diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md index 9d96557f68356dbc16a9041483e48b6a744f5fcb..ae765fb96bcac0453418f51adec23b9bce2365e1 100644 --- a/docs/PAYMENT_CN.md +++ b/docs/PAYMENT_CN.md @@ -22,13 +22,18 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 | 服务商 | 支付方式 | 说明 | |--------|---------|------| | **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 | -| **支付宝官方** | 支付宝 PC 页面支付、H5 手机网站支付 | 直接对接支付宝开放平台,自动根据终端切换 | -| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 | +| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 | +| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 | | **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 | -> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。 +> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。 -> **EasyPay 推荐**:个人推荐 [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`)作为 EasyPay 服务商(链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉)。ZPay 支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 +> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择: +> +> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。 +> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。 +> +> 
支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 --- @@ -56,9 +61,18 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 | **最低金额** | 单笔最低充值金额 | 1 | | **最高金额** | 单笔最高充值金额(留空表示不限制) | - | | **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - | -| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 5 | +| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 | | **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 | -| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 | +| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 | + +### 前台可见支付方式路由 + +当前版本对用户统一展示支付方式,不区分官方渠道还是易支付: + +- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝` +- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信` +- 同一个可见支付方式在同一时刻只能路由到一个来源 +- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式 ### 负载均衡策略 @@ -108,7 +122,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 ### 支付宝官方 -直接对接支付宝开放平台,支持 PC 页面支付和 H5 手机网站支付。 +直接对接支付宝开放平台。移动端走支付宝手机网站支付跳转;桌面端优先使用当面付返回扫码串,若商户未开通当面付则回退到电脑网站支付,并将收银台链接同时返回给前端用于渲染二维码或直接打开支付页。 | 参数 | 说明 | 必填 | |------|------|------| @@ -118,7 +132,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 ### 微信官方 -直接对接微信支付 APIv3,支持 Native 扫码支付和 H5 支付。 +直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。 | 参数 | 说明 | 必填 | |------|------|------| @@ -127,8 +141,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 | **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 | | **APIv3 密钥** | 32 位 APIv3 密钥 | 是 | | **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 | -| **微信支付公钥 ID** | 微信支付公钥 ID | 否 | -| **商户证书序列号** | 商户证书序列号 | 否 | +| **微信支付公钥 ID** | 微信支付公钥 ID | 是 | +| **商户证书序列号** | 商户证书序列号 | 是 | ### Stripe @@ -215,8 +229,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 ▼ 用户完成支付 ├─ EasyPay → 扫码 / H5 跳转 - ├─ 支付宝官方 → PC 页面支付 / H5 手机网站支付 - ├─ 微信官方 → Native 扫码 / H5 支付 + ├─ 支付宝官方 → 桌面扫码单(当面付优先,电脑网站支付回退)/ 移动端支付宝跳转 + ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI └─ Stripe → Payment Element(银行卡/支付宝/微信等) │ ▼ diff --git a/frontend/src/api/__tests__/admin.users.spec.ts b/frontend/src/api/__tests__/admin.users.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..37656b78acec3ea259ff08ad8acfff06707b2e7d --- /dev/null +++ b/frontend/src/api/__tests__/admin.users.spec.ts @@ -0,0 +1,117 @@ +import { beforeEach, describe, expect, it, 
vi } from 'vitest' + +const { post } = vi.hoisted(() => ({ + post: vi.fn(), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post, + }, +})) + +import { + bindUserAuthIdentity, + type AdminBindAuthIdentityRequest, + type AdminBoundAuthIdentity, +} from '@/api/admin/users' + +type Assert = T +type IsExact = ( + (() => G extends T ? 1 : 2) extends (() => G extends U ? 1 : 2) + ? ((() => G extends U ? 1 : 2) extends (() => G extends T ? 1 : 2) ? true : false) + : false +) + +type ExpectedAdminBindAuthIdentityRequest = { + provider_type: string + provider_key: string + provider_subject: string + issuer?: string + metadata?: Record + channel?: { + channel: string + channel_app_id: string + channel_subject: string + metadata?: Record + } +} + +type ExpectedAdminBoundAuthIdentity = { + user_id: number + provider_type: string + provider_key: string + provider_subject: string + verified_at?: string | null + issuer?: string | null + metadata: Record | null + created_at: string + updated_at: string + channel?: { + channel: string + channel_app_id: string + channel_subject: string + metadata: Record | null + created_at: string + updated_at: string + } | null +} + +const requestContractExact: Assert< + IsExact +> = true +const responseContractExact: Assert< + IsExact +> = true + +describe('admin users api auth identity binding', () => { + beforeEach(() => { + post.mockReset() + }) + + it('posts the backend-compatible auth identity bind payload and returns the backend response shape', async () => { + const payload: AdminBindAuthIdentityRequest = { + provider_type: 'wechat', + provider_key: 'wechat-main', + provider_subject: 'union-123', + metadata: { source: 'admin-repair' }, + channel: { + channel: 'open', + channel_app_id: 'wx-open', + channel_subject: 'openid-123', + metadata: { scene: 'migration' }, + }, + } + + const response: AdminBoundAuthIdentity = { + user_id: 9, + provider_type: 'wechat', + provider_key: 'wechat-main', + provider_subject: 'union-123', + 
verified_at: '2026-04-22T00:00:00Z', + issuer: null, + metadata: { source: 'admin-repair' }, + created_at: '2026-04-22T00:00:00Z', + updated_at: '2026-04-22T00:00:00Z', + channel: { + channel: 'open', + channel_app_id: 'wx-open', + channel_subject: 'openid-123', + metadata: { scene: 'migration' }, + created_at: '2026-04-22T00:00:00Z', + updated_at: '2026-04-22T00:00:00Z', + }, + } + post.mockResolvedValue({ data: response }) + + const result = await bindUserAuthIdentity(9, payload) + + expect(post).toHaveBeenCalledWith('/admin/users/9/auth-identities', payload) + expect(result).toEqual(response) + }) + + it('keeps bind auth identity request and response types aligned with the backend contract', () => { + expect(requestContractExact).toBe(true) + expect(responseContractExact).toBe(true) + }) +}) diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..a484d7ed978db8d3738f6a111d86a0077405b2c7 --- /dev/null +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -0,0 +1,184 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const post = vi.fn() + +vi.mock('@/api/client', () => ({ + apiClient: { + post + } +})) + +describe('oauth adoption auth api', () => { + beforeEach(() => { + post.mockReset() + post.mockResolvedValue({ data: {} }) + localStorage.clear() + document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/' + }) + + it('posts adoption decisions when exchanging pending oauth completion', async () => { + const { exchangePendingOAuthCompletion } = await import('@/api/auth') + + await exchangePendingOAuthCompletion({ + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts bind-login decisions when finalizing pending oauth bind flow', async () => { + const { 
completePendingOAuthBindLogin } = await import('@/api/auth') + + await completePendingOAuthBindLogin({ + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts linuxdo invitation completion with adoption decisions', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts linuxdo create-account completion with adoption decisions', async () => { + const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth') + + await createPendingLinuxDoOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts oidc invitation completion with adoption decisions', async () => { + const { completeOIDCOAuthRegistration } = await import('@/api/auth') + + await completeOIDCOAuthRegistration('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts oidc create-account completion with adoption decisions', async () => { + const { createPendingOIDCOAuthAccount } = await import('@/api/auth') + + await createPendingOIDCOAuthAccount('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 
'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts wechat invitation completion with adoption decisions', async () => { + const { completeWeChatOAuthRegistration } = await import('@/api/auth') + + await completeWeChatOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: true + }) + }) + + it('posts wechat create-account completion with adoption decisions', async () => { + const { createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: false + }) + }) + + it('classifies oauth completion results as login or bind', async () => { + const { getOAuthCompletionKind } = await import('@/api/auth') + + expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login') + expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind') + }) + + it('provides bind-login utility helpers for invitation and suggested profile states', async () => { + const { + getPendingOAuthBindLoginKind, + hasPendingOAuthSuggestedProfile, + isPendingOAuthCreateAccountRequired + } = await import('@/api/auth') + + expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login') + expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind') + expect( + isPendingOAuthCreateAccountRequired({ + error: 'invitation_required' + }) + ).toBe(true) + expect( + isPendingOAuthCreateAccountRequired({ + error: 'other' + }) + ).toBe(false) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_display_name: 'OAuth Nick' + }) + ).toBe(true) + expect( + 
hasPendingOAuthSuggestedProfile({ + suggested_avatar_url: 'https://cdn.example/avatar.png' + }) + ).toBe(true) + expect(hasPendingOAuthSuggestedProfile({})).toBe(false) + }) + + it('requests an HttpOnly oauth bind cookie before redirect binding', async () => { + localStorage.setItem('auth_token', 'access-token-value') + const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth') + + await prepareOAuthBindAccessTokenCookie() + + expect(post).toHaveBeenCalledWith('/auth/oauth/bind-token') + }) +}) diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts index 0f663e76ac89457e67b9ca6e483548ec38e527fc..a46c39eb46fb579f5dde903d5ecbfef53c0fef8f 100644 --- a/frontend/src/api/__tests__/client.spec.ts +++ b/frontend/src/api/__tests__/client.spec.ts @@ -91,6 +91,22 @@ describe('API Client', () => { const config = adapter.mock.calls[0][0] expect(config.params?.timezone).toBeUndefined() }) + + it('请求默认带 withCredentials 以支持跨域 cookie', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.post('/auth/oauth/bind-token') + + const config = adapter.mock.calls[0][0] + expect(config.withCredentials).toBe(true) + }) }) // --- 响应拦截器 --- diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..e38fba57fbc6525000d2b7ae9248148ecb031c9a --- /dev/null +++ b/frontend/src/api/__tests__/payment.spec.ts @@ -0,0 +1,40 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { get, post } = vi.hoisted(() => ({ + get: vi.fn(), + post: vi.fn(), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + get, + post, + }, +})) + +import { paymentAPI } from '@/api/payment' + +describe('payment api', () => { + beforeEach(() => { + get.mockReset() + 
post.mockReset() + get.mockResolvedValue({ data: {} }) + post.mockResolvedValue({ data: {} }) + }) + + it('keeps legacy public out_trade_no verification for upgrade compatibility', async () => { + await paymentAPI.verifyOrderPublic('legacy-order-no') + + expect(post).toHaveBeenCalledWith('/payment/public/orders/verify', { + out_trade_no: 'legacy-order-no', + }) + }) + + it('keeps signed public resume-token resolve endpoint', async () => { + await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123') + + expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', { + resume_token: 'resume-token-123', + }) + }) +}) diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..10f6247ab7eca93114463b0bacbf66add9d80d03 --- /dev/null +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -0,0 +1,131 @@ +import { describe, expect, it } from "vitest"; + +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + type UpdateSettingsRequest, +} from "@/api/admin/settings"; + +describe("admin settings auth source defaults helpers", () => { + it("builds auth source defaults state from flat settings fields", () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_balance: 9.5, + auth_source_default_email_concurrency: 3, + auth_source_default_email_subscriptions: [ + { group_id: 1, validity_days: 30 }, + ], + auth_source_default_email_grant_on_signup: false, + auth_source_default_email_grant_on_first_bind: true, + auth_source_default_linuxdo_balance: 6, + auth_source_default_linuxdo_concurrency: 8, + auth_source_default_linuxdo_subscriptions: [ + { group_id: 2, validity_days: 60 }, + ], + auth_source_default_linuxdo_grant_on_signup: true, + auth_source_default_linuxdo_grant_on_first_bind: false, + }); + + expect(state.email).toEqual({ + balance: 
9.5, + concurrency: 3, + subscriptions: [{ group_id: 1, validity_days: 30 }], + grant_on_signup: false, + grant_on_first_bind: true, + }); + expect(state.linuxdo).toEqual({ + balance: 6, + concurrency: 8, + subscriptions: [{ group_id: 2, validity_days: 60 }], + grant_on_signup: true, + grant_on_first_bind: false, + }); + expect(state.oidc).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }); + expect(state.wechat).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }); + }); + + it("defaults grant-on-signup to disabled when settings are missing", () => { + const state = buildAuthSourceDefaultsState({}); + + expect(state.email.grant_on_signup).toBe(false); + expect(state.linuxdo.grant_on_signup).toBe(false); + expect(state.oidc.grant_on_signup).toBe(false); + expect(state.wechat.grant_on_signup).toBe(false); + }); + + it("appends auth source defaults back onto update payload", () => { + const payload: UpdateSettingsRequest = { + site_name: "Sub2API", + }; + + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 1.25, + concurrency: 2, + subscriptions: [{ group_id: 3, validity_days: 7 }], + grant_on_signup: true, + grant_on_first_bind: false, + }, + linuxdo: { + balance: 0, + concurrency: 6, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: true, + }, + oidc: { + balance: 4, + concurrency: 9, + subscriptions: [{ group_id: 9, validity_days: 90 }], + grant_on_signup: true, + grant_on_first_bind: true, + }, + wechat: { + balance: 2, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }, + }); + + expect(payload).toMatchObject({ + site_name: "Sub2API", + auth_source_default_email_balance: 1.25, + auth_source_default_email_concurrency: 2, + auth_source_default_email_subscriptions: [ + { group_id: 3, validity_days: 7 }, + ], + 
auth_source_default_email_grant_on_signup: true, + auth_source_default_email_grant_on_first_bind: false, + auth_source_default_linuxdo_balance: 0, + auth_source_default_linuxdo_concurrency: 6, + auth_source_default_linuxdo_subscriptions: [], + auth_source_default_linuxdo_grant_on_signup: false, + auth_source_default_linuxdo_grant_on_first_bind: true, + auth_source_default_oidc_balance: 4, + auth_source_default_oidc_concurrency: 9, + auth_source_default_oidc_subscriptions: [ + { group_id: 9, validity_days: 90 }, + ], + auth_source_default_oidc_grant_on_signup: true, + auth_source_default_oidc_grant_on_first_bind: true, + auth_source_default_wechat_balance: 2, + auth_source_default_wechat_concurrency: 5, + auth_source_default_wechat_subscriptions: [], + auth_source_default_wechat_grant_on_signup: false, + auth_source_default_wechat_grant_on_first_bind: false, + }); + }); +}); diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ad355afee372d5c0ee78686c370153efb6e6b407 --- /dev/null +++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts @@ -0,0 +1,63 @@ +import { describe, expect, it } from 'vitest' + +import { + getPaymentVisibleMethodSourceOptions, + normalizePaymentVisibleMethodSource, +} from '@/api/admin/settings' + +describe('admin settings payment visible method helpers', () => { + it('normalizes aliases into canonical source keys per visible method', () => { + expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay') + expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay') + expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay') + + expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay') + expect(normalizePaymentVisibleMethodSource('wxpay', 
'wechat')).toBe('official_wxpay') + expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay') + }) + + it('rejects unknown or cross-method source values', () => { + expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('') + expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('') + expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('') + expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('') + }) + + it('exposes method-scoped source options instead of arbitrary strings', () => { + expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([ + { + value: '', + labelZh: '未配置', + labelEn: 'Not configured', + }, + { + value: 'official_alipay', + labelZh: '支付宝官方', + labelEn: 'Official Alipay', + }, + { + value: 'easypay_alipay', + labelZh: '易支付支付宝', + labelEn: 'EasyPay Alipay', + }, + ]) + + expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([ + { + value: '', + labelZh: '未配置', + labelEn: 'Not configured', + }, + { + value: 'official_wxpay', + labelZh: '微信官方', + labelEn: 'Official WeChat Pay', + }, + { + value: 'easypay_wxpay', + labelZh: '易支付微信', + labelEn: 'EasyPay WeChat Pay', + }, + ]) + }) +}) diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..eccb72149b4a28c4f63d5a8fe459a72eacc4e081 --- /dev/null +++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from "vitest"; + +import { + defaultWeChatConnectScopesForMode, + normalizeWeChatConnectMode, +} from "@/api/admin/settings"; + +describe("admin settings wechat connect helpers", () => { + it("normalizes legacy or noisy mode values to the backend contract", () => { + expect(normalizeWeChatConnectMode("OPEN")).toBe("open"); + expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open"); + 
expect(normalizeWeChatConnectMode("mp")).toBe("mp"); + expect(normalizeWeChatConnectMode("official_account")).toBe("mp"); + expect(normalizeWeChatConnectMode("unknown")).toBe("open"); + }); + + it("maps each mode to the backend default scopes", () => { + expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login"); + expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo"); + }); +}); diff --git a/frontend/src/api/__tests__/user.spec.ts b/frontend/src/api/__tests__/user.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..887046da1842e83b8613e11c4cb01a967f43f571 --- /dev/null +++ b/frontend/src/api/__tests__/user.spec.ts @@ -0,0 +1,32 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +describe('user api oauth binding urls', () => { + beforeEach(() => { + vi.resetModules() + vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/api/v1') + }) + + afterEach(() => { + vi.unstubAllEnvs() + }) + + it('builds third-party bind urls against the bind start endpoint', async () => { + const { buildOAuthBindingStartURL } = await import('@/api/user') + + expect(buildOAuthBindingStartURL('linuxdo', { redirectTo: '/settings/profile' })).toBe( + 'https://api.example.com/api/v1/auth/oauth/linuxdo/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user' + ) + expect( + buildOAuthBindingStartURL('wechat', { + redirectTo: '/settings/profile', + wechatOAuthSettings: { + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + wechat_oauth_mobile_enabled: false + } + }) + ).toBe( + 'https://api.example.com/api/v1/auth/oauth/wechat/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user&mode=open' + ) + }) +}) diff --git a/frontend/src/api/admin/channelMonitor.ts b/frontend/src/api/admin/channelMonitor.ts new file mode 100644 index 0000000000000000000000000000000000000000..949c4bc899941db45678f383b2c032af9a39d568 --- /dev/null +++ b/frontend/src/api/admin/channelMonitor.ts @@ 
-0,0 +1,202 @@ +/** + * Admin Channel Monitor API endpoints + * Handles channel monitor (uptime/health) management for administrators + */ + +import { apiClient } from '../client' + +export type Provider = 'openai' | 'anthropic' | 'gemini' +export type MonitorStatus = 'operational' | 'degraded' | 'failed' | 'error' +export type BodyOverrideMode = 'off' | 'merge' | 'replace' + +export interface ChannelMonitor { + id: number + name: string + provider: Provider + endpoint: string + api_key_masked: string + /** + * True when the stored encrypted API key cannot be decrypted (e.g. the + * encryption key has changed). Admin must re-edit the monitor to provide + * a fresh key. Backend skips checks for these monitors. + */ + api_key_decrypt_failed?: boolean + primary_model: string + extra_models: string[] + group_name: string + enabled: boolean + interval_seconds: number + last_checked_at: string | null + created_by: number + created_at: string + updated_at: string + /** Latest status of the primary model (empty when no history yet) */ + primary_status: MonitorStatus | '' + /** Latest latency of the primary model in ms (null when no history yet) */ + primary_latency_ms: number | null + /** Primary model 7-day availability percentage (0-100) */ + availability_7d: number + /** Latest status per extra model (used for hover tooltip) */ + extra_models_status: ExtraModelStatus[] + /** 请求自定义快照字段(高级设置) */ + template_id: number | null + extra_headers: Record + body_override_mode: BodyOverrideMode + body_override: Record | null +} + +export interface ExtraModelStatus { + model: string + status: MonitorStatus | '' + latency_ms: number | null +} + +export interface ListParams { + page?: number + page_size?: number + provider?: Provider + enabled?: boolean + search?: string +} + +export interface ListResponse { + items: ChannelMonitor[] + total: number + page: number + page_size: number + pages: number +} + +export interface CreateParams { + name: string + provider: Provider + endpoint: 
string + api_key: string + primary_model: string + extra_models?: string[] + group_name?: string + enabled?: boolean + interval_seconds: number + template_id?: number | null + extra_headers?: Record + body_override_mode?: BodyOverrideMode + body_override?: Record | null +} + +// Update request: api_key 空串 = 不修改;clear_template=true 时把 template_id 置空 +export type UpdateParams = Partial & { + clear_template?: boolean +} + +export interface CheckResult { + model: string + status: MonitorStatus + latency_ms: number | null + ping_latency_ms: number | null + message: string + checked_at: string +} + +export interface RunNowResponse { + results: CheckResult[] +} + +export interface HistoryItem { + id: number + model: string + status: MonitorStatus + latency_ms: number | null + ping_latency_ms: number | null + message: string + checked_at: string +} + +export interface HistoryParams { + model?: string + limit?: number +} + +export interface HistoryResponse { + items: HistoryItem[] +} + +/** + * List channel monitors with pagination and filters + */ +export async function list( + params: ListParams = {}, + options?: { signal?: AbortSignal } +): Promise { + const { data } = await apiClient.get('/admin/channel-monitors', { + params, + signal: options?.signal, + }) + return data +} + +/** + * Get a channel monitor by ID + */ +export async function get(id: number): Promise { + const { data } = await apiClient.get(`/admin/channel-monitors/${id}`) + return data +} + +/** + * Create a new channel monitor + */ +export async function create(params: CreateParams): Promise { + const { data } = await apiClient.post('/admin/channel-monitors', params) + return data +} + +/** + * Update an existing channel monitor. + * api_key field: empty string means "do not modify". 
+ */ +export async function update(id: number, params: UpdateParams): Promise { + const { data } = await apiClient.put(`/admin/channel-monitors/${id}`, params) + return data +} + +/** + * Delete a channel monitor + */ +export async function del(id: number): Promise { + await apiClient.delete(`/admin/channel-monitors/${id}`) +} + +/** + * Trigger an immediate manual check for a channel monitor. + * Returns the latest check results for primary + extra models. + */ +export async function runNow(id: number): Promise { + const { data } = await apiClient.post(`/admin/channel-monitors/${id}/run`) + return data +} + +/** + * List historical check results for a monitor. + */ +export async function listHistory( + id: number, + params: HistoryParams = {} +): Promise { + const { data } = await apiClient.get( + `/admin/channel-monitors/${id}/history`, + { params } + ) + return data +} + +export const channelMonitorAPI = { + list, + get, + create, + update, + del, + runNow, + listHistory, +} + +export default channelMonitorAPI diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts new file mode 100644 index 0000000000000000000000000000000000000000..01b3c2d05ab5b283ae702d80b541914c2d2f2733 --- /dev/null +++ b/frontend/src/api/admin/channelMonitorTemplate.ts @@ -0,0 +1,132 @@ +/** + * Admin Channel Monitor Request Template API. 
+ * + * 模板 = 一组可复用的 headers + 可选 body 覆盖配置。 + * 应用到监控 = 拷贝快照;模板后续变动不自动同步,需手动点「应用到关联监控」刷新。 + */ + +import { apiClient } from '../client' +import type { BodyOverrideMode, Provider } from './channelMonitor' + +export interface ChannelMonitorTemplate { + id: number + name: string + provider: Provider + description: string + extra_headers: Record + body_override_mode: BodyOverrideMode + body_override: Record | null + created_at: string + updated_at: string + /** 关联的监控数量(快照来自此模板,仅 template_id 匹配即可) */ + associated_monitors: number +} + +export interface ListParams { + provider?: Provider +} + +export interface ListResponse { + items: ChannelMonitorTemplate[] +} + +export interface CreateParams { + name: string + provider: Provider + description?: string + extra_headers?: Record + body_override_mode?: BodyOverrideMode + body_override?: Record | null +} + +export interface UpdateParams { + name?: string + description?: string + extra_headers?: Record + body_override_mode?: BodyOverrideMode + body_override?: Record | null +} + +export interface ApplyResponse { + affected: number +} + +export interface AssociatedMonitorBrief { + id: number + name: string + provider: Provider + enabled: boolean +} + +export interface AssociatedMonitorsResponse { + items: AssociatedMonitorBrief[] +} + +export async function list(params: ListParams = {}): Promise { + const { data } = await apiClient.get('/admin/channel-monitor-templates', { + params, + }) + return data +} + +export async function get(id: number): Promise { + const { data } = await apiClient.get( + `/admin/channel-monitor-templates/${id}`, + ) + return data +} + +export async function create(params: CreateParams): Promise { + const { data } = await apiClient.post( + '/admin/channel-monitor-templates', + params, + ) + return data +} + +export async function update(id: number, params: UpdateParams): Promise { + const { data } = await apiClient.put( + `/admin/channel-monitor-templates/${id}`, + params, + ) + return data +} + 
+export async function del(id: number): Promise { + await apiClient.delete(`/admin/channel-monitor-templates/${id}`) +} + +/** + * Apply the template to the specified associated monitors (overwrite snapshot fields). + * monitorIds must be a non-empty subset of the template's associated monitors. + * Returns count of actually affected monitors. + */ +export async function apply(id: number, monitorIds: number[]): Promise { + const { data } = await apiClient.post( + `/admin/channel-monitor-templates/${id}/apply`, + { monitor_ids: monitorIds }, + ) + return data +} + +/** + * List monitors currently associated to this template (used by apply picker). + */ +export async function listAssociatedMonitors(id: number): Promise { + const { data } = await apiClient.get( + `/admin/channel-monitor-templates/${id}/monitors`, + ) + return data +} + +export const channelMonitorTemplateAPI = { + list, + get, + create, + update, + del, + apply, + listAssociatedMonitors, +} + +export default channelMonitorTemplateAPI diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index f129ceaabea6c3de7e8dd10c20c7e38e42404eed..9d4301340f7fefa94f6263422acb4be73cca0ecc 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -4,8 +4,9 @@ */ import { apiClient } from '../client' +import type { BillingMode, ChannelStatus, BillingModelSource } from '@/constants/channel' -export type BillingMode = 'token' | 'per_request' | 'image' +export type { BillingMode } from '@/constants/channel' export interface PricingInterval { id?: number @@ -46,8 +47,8 @@ export interface Channel { id: number name: string description: string - status: string - billing_model_source: string // "requested" | "upstream" + status: ChannelStatus + billing_model_source: BillingModelSource restrict_models: boolean features_config?: Record group_ids: number[] diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 
8739d5cbe9a8b74f4e0af1f61ec8b30aa33fb21f..6b94b7992549ef06f067b3c3ab45cb26f61a1955 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -164,7 +164,8 @@ export interface GroupRateMultiplierEntry { user_email: string user_notes: string user_status: string - rate_multiplier: number + rate_multiplier?: number | null + rpm_override?: number | null } /** @@ -205,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message: /** * Batch set rate multipliers for users in a group - * @param id - Group ID - * @param entries - Array of { user_id, rate_multiplier } - * @returns Success confirmation + * Only touches rate_multiplier column; preserves rpm_override on existing rows. */ export async function batchSetGroupRateMultipliers( id: number, @@ -220,6 +219,60 @@ export async function batchSetGroupRateMultipliers( return data } +/** + * RPM override entry for a user in a group + */ +export interface GroupRPMOverrideEntry { + user_id: number + user_name: string + user_email: string + user_notes: string + user_status: string + rpm_override: number +} + +/** + * Get RPM overrides for users in a group (subset of rate-multipliers endpoint). + */ +export async function getGroupRPMOverrides(id: number): Promise { + const { data } = await apiClient.get( + `/admin/groups/${id}/rate-multipliers` + ) + return data + .filter(e => e.rpm_override != null) + .map(e => ({ + user_id: e.user_id, + user_name: e.user_name, + user_email: e.user_email, + user_notes: e.user_notes, + user_status: e.user_status, + rpm_override: e.rpm_override as number + })) +} + +/** + * Batch set RPM overrides for users in a group. + * Only touches rpm_override column; preserves rate_multiplier on existing rows. 
+ */ +export async function batchSetGroupRPMOverrides( + id: number, + entries: Array<{ user_id: number; rpm_override: number }> +): Promise<{ message: string }> { + const { data } = await apiClient.put<{ message: string }>( + `/admin/groups/${id}/rpm-overrides`, + { entries } + ) + return data +} + +/** + * Clear all RPM overrides for a group (preserves rate_multiplier). + */ +export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`) + return data +} + /** * Get usage summary (today + cumulative cost) for all groups * @param timezone - IANA timezone string (e.g. "Asia/Shanghai") @@ -262,6 +315,9 @@ export const groupsAPI = { getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, + getGroupRPMOverrides, + clearGroupRPMOverrides, + batchSetGroupRPMOverrides, updateSortOrder, getUsageSummary, getCapacitySummary diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 72597365434ff09687adc59155464e103fc0221f..9cda5814cc168cce84a1d2b3c47ba5540638d685 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -26,6 +26,8 @@ import scheduledTestsAPI from './scheduledTests' import backupAPI from './backup' import tlsFingerprintProfileAPI from './tlsFingerprintProfile' import channelsAPI from './channels' +import channelMonitorAPI from './channelMonitor' +import channelMonitorTemplateAPI from './channelMonitorTemplate' import adminPaymentAPI from './payment' /** @@ -55,6 +57,8 @@ export const adminAPI = { backup: backupAPI, tlsFingerprintProfiles: tlsFingerprintProfileAPI, channels: channelsAPI, + channelMonitor: channelMonitorAPI, + channelMonitorTemplate: channelMonitorTemplateAPI, payment: adminPaymentAPI } @@ -82,6 +86,8 @@ export { backupAPI, tlsFingerprintProfileAPI, channelsAPI, + channelMonitorAPI, + channelMonitorTemplateAPI, adminPaymentAPI } diff --git 
a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 1e4a3053092e5870d8004ee433ffe86f80581ed2..b9f246634fb959c71d8307b9f707c24af345f306 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -3,12 +3,293 @@ * Handles system settings management for administrators */ -import { apiClient } from '../client' -import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types' +import { apiClient } from "../client"; +import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types"; export interface DefaultSubscriptionSetting { - group_id: number - validity_days: number + group_id: number; + validity_days: number; +} + +export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat"; + +export interface AuthSourceDefaultsValue { + balance: number; + concurrency: number; + subscriptions: DefaultSubscriptionSetting[]; + grant_on_signup: boolean; + grant_on_first_bind: boolean; +} + +export type AuthSourceDefaultsState = Record< + AuthSourceType, + AuthSourceDefaultsValue +>; +export type PaymentVisibleMethod = "alipay" | "wxpay"; +export type PaymentVisibleMethodSource = + | "" + | "official_alipay" + | "easypay_alipay" + | "official_wxpay" + | "easypay_wxpay"; +export type WeChatConnectMode = "open" | "mp" | "mobile"; + +export interface PaymentVisibleMethodSourceOption { + value: PaymentVisibleMethodSource; + labelZh: string; + labelEn: string; +} + +export interface WeChatConnectModeOption { + value: WeChatConnectMode; + labelZh: string; + labelEn: string; +} + +const AUTH_SOURCE_TYPES: AuthSourceType[] = [ + "email", + "linuxdo", + "oidc", + "wechat", +]; +const AUTH_SOURCE_DEFAULT_BALANCE = 0; +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; +const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record< + PaymentVisibleMethod, + PaymentVisibleMethodSourceOption[] +> = { + alipay: [ + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_alipay", + labelZh: 
"支付宝官方", + labelEn: "Official Alipay", + }, + { + value: "easypay_alipay", + labelZh: "易支付支付宝", + labelEn: "EasyPay Alipay", + }, + ], + wxpay: [ + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_wxpay", + labelZh: "微信官方", + labelEn: "Official WeChat Pay", + }, + { + value: "easypay_wxpay", + labelZh: "易支付微信", + labelEn: "EasyPay WeChat Pay", + }, + ], +}; +const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record< + PaymentVisibleMethod, + Record +> = { + alipay: { + official_alipay: "official_alipay", + alipay: "official_alipay", + alipay_direct: "official_alipay", + official: "official_alipay", + easypay_alipay: "easypay_alipay", + easypay: "easypay_alipay", + }, + wxpay: { + official_wxpay: "official_wxpay", + wxpay: "official_wxpay", + wxpay_direct: "official_wxpay", + wechat: "official_wxpay", + official: "official_wxpay", + easypay_wxpay: "easypay_wxpay", + easypay: "easypay_wxpay", + }, +}; +const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [ + { value: "open", labelZh: "PC 应用", labelEn: "PC App" }, + { + value: "mp", + labelZh: "公众号", + labelEn: "Official Account", + }, + { + value: "mobile", + labelZh: "移动应用", + labelEn: "Mobile App", + }, +]; +const WECHAT_CONNECT_MODE_ALIASES: Record = { + open: "open", + open_platform: "open", + official: "open", + wx_open: "open", + mp: "mp", + official_account: "mp", + wechat_mp: "mp", + mini_program: "mp", + mobile: "mobile", + mobile_app: "mobile", + native_app: "mobile", +}; + +export function normalizeDefaultSubscriptionSettings( + subscriptions: DefaultSubscriptionSetting[] | null | undefined, +): DefaultSubscriptionSetting[] { + if (!Array.isArray(subscriptions)) return []; + + return subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: Math.floor(item.group_id), + validity_days: Math.min( + 36500, + Math.max(1, Math.floor(item.validity_days)), + ), + })); +} + +export function buildAuthSourceDefaultsState( + 
settings: Partial, +): AuthSourceDefaultsState { + const raw = settings as Record; + + return AUTH_SOURCE_TYPES.reduce((acc, source) => { + const subscriptions = raw[`auth_source_default_${source}_subscriptions`]; + acc[source] = { + balance: Number( + raw[`auth_source_default_${source}_balance`] ?? + AUTH_SOURCE_DEFAULT_BALANCE, + ), + concurrency: Math.max( + 1, + Number( + raw[`auth_source_default_${source}_concurrency`] ?? + AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ), + subscriptions: normalizeDefaultSubscriptionSettings( + Array.isArray(subscriptions) + ? (subscriptions as DefaultSubscriptionSetting[]) + : [], + ), + grant_on_signup: + raw[`auth_source_default_${source}_grant_on_signup`] === true, + grant_on_first_bind: + raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + }; + return acc; + }, {} as AuthSourceDefaultsState); +} + +export function appendAuthSourceDefaultsToUpdateRequest( + payload: UpdateSettingsRequest, + authSourceDefaults: AuthSourceDefaultsState, +): UpdateSettingsRequest { + const target = payload as Record; + + for (const source of AUTH_SOURCE_TYPES) { + const current = authSourceDefaults[source]; + target[`auth_source_default_${source}_balance`] = + Number(current.balance) || 0; + target[`auth_source_default_${source}_concurrency`] = Math.max( + 1, + Math.floor( + Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ); + target[`auth_source_default_${source}_subscriptions`] = + normalizeDefaultSubscriptionSettings(current.subscriptions); + target[`auth_source_default_${source}_grant_on_signup`] = + current.grant_on_signup; + target[`auth_source_default_${source}_grant_on_first_bind`] = + current.grant_on_first_bind; + } + + return payload; +} + +export function getPaymentVisibleMethodSourceOptions( + method: PaymentVisibleMethod, +): PaymentVisibleMethodSourceOption[] { + return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]; +} + +export function normalizePaymentVisibleMethodSource( + method: 
PaymentVisibleMethod, + source: unknown, +): PaymentVisibleMethodSource { + if (typeof source !== "string") return ""; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return ""; + + return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ""; +} + +export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] { + return WECHAT_CONNECT_MODE_OPTIONS; +} + +export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode { + if (typeof source !== "string") return "open"; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return "open"; + + return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open"; +} + +export function defaultWeChatConnectScopesForMode(mode: unknown): string { + switch (normalizeWeChatConnectMode(mode)) { + case "mp": + return "snsapi_userinfo"; + case "mobile": + return ""; + default: + return "snsapi_login"; + } +} + +export function resolveWeChatConnectModeCapabilities( + openEnabled: unknown, + mpEnabled: unknown, + mobileEnabled: unknown, + legacyMode: unknown, +): { openEnabled: boolean; mpEnabled: boolean; mobileEnabled: boolean } { + if ( + typeof openEnabled === "boolean" || + typeof mpEnabled === "boolean" || + typeof mobileEnabled === "boolean" + ) { + return { + openEnabled: openEnabled === true, + mpEnabled: mpEnabled === true, + mobileEnabled: mobileEnabled === true, + }; + } + + switch (normalizeWeChatConnectMode(legacyMode)) { + case "mp": + return { openEnabled: false, mpEnabled: true, mobileEnabled: false }; + case "mobile": + return { openEnabled: false, mpEnabled: false, mobileEnabled: true }; + default: + return { openEnabled: true, mpEnabled: false, mobileEnabled: false }; + } +} + +export function deriveWeChatConnectStoredMode( + openEnabled: boolean, + mpEnabled: boolean, + mobileEnabled: boolean, + legacyMode: unknown, +): WeChatConnectMode { + if (mpEnabled) return "mp"; + if (mobileEnabled) return "mobile"; + if (openEnabled) return "open"; + 
return normalizeWeChatConnectMode(legacyMode); } /** @@ -16,241 +297,343 @@ export interface DefaultSubscriptionSetting { */ export interface SystemSettings { // Registration settings - registration_enabled: boolean - email_verify_enabled: boolean - registration_email_suffix_whitelist: string[] - promo_code_enabled: boolean - password_reset_enabled: boolean - frontend_url: string - invitation_code_enabled: boolean - totp_enabled: boolean // TOTP 双因素认证 - totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 + registration_enabled: boolean; + email_verify_enabled: boolean; + registration_email_suffix_whitelist: string[]; + promo_code_enabled: boolean; + password_reset_enabled: boolean; + frontend_url: string; + invitation_code_enabled: boolean; + totp_enabled: boolean; // TOTP 双因素认证 + totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings - default_balance: number - default_concurrency: number - default_subscriptions: DefaultSubscriptionSetting[] + default_balance: number; + default_concurrency: number; + default_user_rpm_limit: number; + default_subscriptions: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + 
auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; // OEM settings - site_name: string - site_logo: string - site_subtitle: string - api_base_url: string - contact_info: string - doc_url: string - home_content: string - hide_ccs_import_button: boolean - table_default_page_size: number - table_page_size_options: number[] - backend_mode_enabled: boolean - custom_menu_items: CustomMenuItem[] - custom_endpoints: CustomEndpoint[] + site_name: string; + site_logo: string; + site_subtitle: string; + api_base_url: string; + contact_info: string; + doc_url: string; + home_content: string; + hide_ccs_import_button: boolean; + table_default_page_size: number; + table_page_size_options: number[]; + backend_mode_enabled: boolean; + custom_menu_items: CustomMenuItem[]; + custom_endpoints: CustomEndpoint[]; // SMTP settings - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password_configured: boolean - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password_configured: boolean; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; // Cloudflare Turnstile settings - turnstile_enabled: boolean - turnstile_site_key: string - turnstile_secret_key_configured: boolean + turnstile_enabled: boolean; + turnstile_site_key: string; + turnstile_secret_key_configured: boolean; // LinuxDo Connect OAuth settings - linuxdo_connect_enabled: boolean - linuxdo_connect_client_id: string - linuxdo_connect_client_secret_configured: boolean - linuxdo_connect_redirect_url: string + linuxdo_connect_enabled: boolean; + linuxdo_connect_client_id: string; + 
linuxdo_connect_client_secret_configured: boolean; + linuxdo_connect_redirect_url: string; + + // WeChat Connect OAuth settings + wechat_connect_enabled: boolean; + wechat_connect_app_id: string; + wechat_connect_app_secret_configured: boolean; + wechat_connect_open_app_id?: string; + wechat_connect_open_app_secret_configured?: boolean; + wechat_connect_mp_app_id?: string; + wechat_connect_mp_app_secret_configured?: boolean; + wechat_connect_mobile_app_id?: string; + wechat_connect_mobile_app_secret_configured?: boolean; + wechat_connect_open_enabled?: boolean; + wechat_connect_mp_enabled?: boolean; + wechat_connect_mobile_enabled?: boolean; + wechat_connect_mode: string; + wechat_connect_scopes: string; + wechat_connect_redirect_url: string; + wechat_connect_frontend_redirect_url: string; // Generic OIDC OAuth settings - oidc_connect_enabled: boolean - oidc_connect_provider_name: string - oidc_connect_client_id: string - oidc_connect_client_secret_configured: boolean - oidc_connect_issuer_url: string - oidc_connect_discovery_url: string - oidc_connect_authorize_url: string - oidc_connect_token_url: string - oidc_connect_userinfo_url: string - oidc_connect_jwks_url: string - oidc_connect_scopes: string - oidc_connect_redirect_url: string - oidc_connect_frontend_redirect_url: string - oidc_connect_token_auth_method: string - oidc_connect_use_pkce: boolean - oidc_connect_validate_id_token: boolean - oidc_connect_allowed_signing_algs: string - oidc_connect_clock_skew_seconds: number - oidc_connect_require_email_verified: boolean - oidc_connect_userinfo_email_path: string - oidc_connect_userinfo_id_path: string - oidc_connect_userinfo_username_path: string + oidc_connect_enabled: boolean; + oidc_connect_provider_name: string; + oidc_connect_client_id: string; + oidc_connect_client_secret_configured: boolean; + oidc_connect_issuer_url: string; + oidc_connect_discovery_url: string; + oidc_connect_authorize_url: string; + oidc_connect_token_url: string; + 
oidc_connect_userinfo_url: string; + oidc_connect_jwks_url: string; + oidc_connect_scopes: string; + oidc_connect_redirect_url: string; + oidc_connect_frontend_redirect_url: string; + oidc_connect_token_auth_method: string; + oidc_connect_use_pkce: boolean; + oidc_connect_validate_id_token: boolean; + oidc_connect_allowed_signing_algs: string; + oidc_connect_clock_skew_seconds: number; + oidc_connect_require_email_verified: boolean; + oidc_connect_userinfo_email_path: string; + oidc_connect_userinfo_id_path: string; + oidc_connect_userinfo_username_path: string; // Model fallback configuration - enable_model_fallback: boolean - fallback_model_anthropic: string - fallback_model_openai: string - fallback_model_gemini: string - fallback_model_antigravity: string + enable_model_fallback: boolean; + fallback_model_anthropic: string; + fallback_model_openai: string; + fallback_model_gemini: string; + fallback_model_antigravity: string; // Identity patch configuration (Claude -> Gemini) - enable_identity_patch: boolean - identity_patch_prompt: string + enable_identity_patch: boolean; + identity_patch_prompt: string; // Ops Monitoring (vNext) - ops_monitoring_enabled: boolean - ops_realtime_monitoring_enabled: boolean - ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds: number + ops_monitoring_enabled: boolean; + ops_realtime_monitoring_enabled: boolean; + ops_query_mode_default: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds: number; // Claude Code version check - min_claude_code_version: string - max_claude_code_version: string + min_claude_code_version: string; + max_claude_code_version: string; // 分组隔离 - allow_ungrouped_key_scheduling: boolean + allow_ungrouped_key_scheduling: boolean; // Gateway forwarding behavior - enable_fingerprint_unification: boolean - enable_metadata_passthrough: boolean - enable_cch_signing: boolean - web_search_emulation_enabled?: boolean + enable_fingerprint_unification: boolean; + 
enable_metadata_passthrough: boolean; + enable_cch_signing: boolean; + web_search_emulation_enabled?: boolean; // Payment configuration - payment_enabled: boolean - payment_min_amount: number - payment_max_amount: number - payment_daily_limit: number - payment_order_timeout_minutes: number - payment_max_pending_orders: number - payment_enabled_types: string[] - payment_balance_disabled: boolean - payment_balance_recharge_multiplier: number - payment_recharge_fee_rate: number - payment_load_balance_strategy: string - payment_product_name_prefix: string - payment_product_name_suffix: string - payment_help_image_url: string - payment_help_text: string - payment_cancel_rate_limit_enabled: boolean - payment_cancel_rate_limit_max: number - payment_cancel_rate_limit_window: number - payment_cancel_rate_limit_unit: string - payment_cancel_rate_limit_window_mode: string + payment_enabled: boolean; + payment_min_amount: number; + payment_max_amount: number; + payment_daily_limit: number; + payment_order_timeout_minutes: number; + payment_max_pending_orders: number; + payment_enabled_types: string[]; + payment_balance_disabled: boolean; + payment_balance_recharge_multiplier: number; + payment_recharge_fee_rate: number; + payment_load_balance_strategy: string; + payment_product_name_prefix: string; + payment_product_name_suffix: string; + payment_help_image_url: string; + payment_help_text: string; + payment_cancel_rate_limit_enabled: boolean; + payment_cancel_rate_limit_max: number; + payment_cancel_rate_limit_window: number; + payment_cancel_rate_limit_unit: string; + payment_cancel_rate_limit_window_mode: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled: boolean - balance_low_notify_threshold: number - 
balance_low_notify_recharge_url: string - account_quota_notify_enabled: boolean - account_quota_notify_emails: NotifyEmailEntry[] + balance_low_notify_enabled: boolean; + balance_low_notify_threshold: number; + balance_low_notify_recharge_url: string; + account_quota_notify_enabled: boolean; + account_quota_notify_emails: NotifyEmailEntry[]; + + // Channel Monitor feature switch + channel_monitor_enabled: boolean; + channel_monitor_default_interval_seconds: number; + + // Available Channels feature switch + available_channels_enabled: boolean; } export interface UpdateSettingsRequest { - registration_enabled?: boolean - email_verify_enabled?: boolean - registration_email_suffix_whitelist?: string[] - promo_code_enabled?: boolean - password_reset_enabled?: boolean - frontend_url?: string - invitation_code_enabled?: boolean - totp_enabled?: boolean // TOTP 双因素认证 - default_balance?: number - default_concurrency?: number - default_subscriptions?: DefaultSubscriptionSetting[] - site_name?: string - site_logo?: string - site_subtitle?: string - api_base_url?: string - contact_info?: string - doc_url?: string - home_content?: string - hide_ccs_import_button?: boolean - table_default_page_size?: number - table_page_size_options?: number[] - backend_mode_enabled?: boolean - custom_menu_items?: CustomMenuItem[] - custom_endpoints?: CustomEndpoint[] - smtp_host?: string - smtp_port?: number - smtp_username?: string - smtp_password?: string - smtp_from_email?: string - smtp_from_name?: string - smtp_use_tls?: boolean - turnstile_enabled?: boolean - turnstile_site_key?: string - turnstile_secret_key?: string - linuxdo_connect_enabled?: boolean - linuxdo_connect_client_id?: string - linuxdo_connect_client_secret?: string - linuxdo_connect_redirect_url?: string - oidc_connect_enabled?: boolean - oidc_connect_provider_name?: string - oidc_connect_client_id?: string - oidc_connect_client_secret?: string - oidc_connect_issuer_url?: string - oidc_connect_discovery_url?: string - 
oidc_connect_authorize_url?: string - oidc_connect_token_url?: string - oidc_connect_userinfo_url?: string - oidc_connect_jwks_url?: string - oidc_connect_scopes?: string - oidc_connect_redirect_url?: string - oidc_connect_frontend_redirect_url?: string - oidc_connect_token_auth_method?: string - oidc_connect_use_pkce?: boolean - oidc_connect_validate_id_token?: boolean - oidc_connect_allowed_signing_algs?: string - oidc_connect_clock_skew_seconds?: number - oidc_connect_require_email_verified?: boolean - oidc_connect_userinfo_email_path?: string - oidc_connect_userinfo_id_path?: string - oidc_connect_userinfo_username_path?: string - enable_model_fallback?: boolean - fallback_model_anthropic?: string - fallback_model_openai?: string - fallback_model_gemini?: string - fallback_model_antigravity?: string - enable_identity_patch?: boolean - identity_patch_prompt?: string - ops_monitoring_enabled?: boolean - ops_realtime_monitoring_enabled?: boolean - ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds?: number - min_claude_code_version?: string - max_claude_code_version?: string - allow_ungrouped_key_scheduling?: boolean - enable_fingerprint_unification?: boolean - enable_metadata_passthrough?: boolean - enable_cch_signing?: boolean + registration_enabled?: boolean; + email_verify_enabled?: boolean; + registration_email_suffix_whitelist?: string[]; + promo_code_enabled?: boolean; + password_reset_enabled?: boolean; + frontend_url?: string; + invitation_code_enabled?: boolean; + totp_enabled?: boolean; // TOTP 双因素认证 + default_balance?: number; + default_concurrency?: number; + default_user_rpm_limit?: number; + default_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + 
auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; + site_name?: string; + site_logo?: string; + site_subtitle?: string; + api_base_url?: string; + contact_info?: string; + doc_url?: string; + home_content?: string; + hide_ccs_import_button?: boolean; + table_default_page_size?: number; + table_page_size_options?: number[]; + backend_mode_enabled?: boolean; + custom_menu_items?: CustomMenuItem[]; + custom_endpoints?: CustomEndpoint[]; + smtp_host?: string; + smtp_port?: number; + smtp_username?: string; + smtp_password?: string; + smtp_from_email?: string; + smtp_from_name?: string; + smtp_use_tls?: boolean; + turnstile_enabled?: boolean; + turnstile_site_key?: string; + turnstile_secret_key?: string; + linuxdo_connect_enabled?: boolean; + linuxdo_connect_client_id?: string; + linuxdo_connect_client_secret?: string; + linuxdo_connect_redirect_url?: string; + wechat_connect_enabled?: boolean; + wechat_connect_app_id?: string; + wechat_connect_app_secret?: string; + wechat_connect_open_app_id?: string; + wechat_connect_open_app_secret?: string; + 
wechat_connect_mp_app_id?: string; + wechat_connect_mp_app_secret?: string; + wechat_connect_mobile_app_id?: string; + wechat_connect_mobile_app_secret?: string; + wechat_connect_open_enabled?: boolean; + wechat_connect_mp_enabled?: boolean; + wechat_connect_mobile_enabled?: boolean; + wechat_connect_mode?: string; + wechat_connect_scopes?: string; + wechat_connect_redirect_url?: string; + wechat_connect_frontend_redirect_url?: string; + oidc_connect_enabled?: boolean; + oidc_connect_provider_name?: string; + oidc_connect_client_id?: string; + oidc_connect_client_secret?: string; + oidc_connect_issuer_url?: string; + oidc_connect_discovery_url?: string; + oidc_connect_authorize_url?: string; + oidc_connect_token_url?: string; + oidc_connect_userinfo_url?: string; + oidc_connect_jwks_url?: string; + oidc_connect_scopes?: string; + oidc_connect_redirect_url?: string; + oidc_connect_frontend_redirect_url?: string; + oidc_connect_token_auth_method?: string; + oidc_connect_use_pkce?: boolean; + oidc_connect_validate_id_token?: boolean; + oidc_connect_allowed_signing_algs?: string; + oidc_connect_clock_skew_seconds?: number; + oidc_connect_require_email_verified?: boolean; + oidc_connect_userinfo_email_path?: string; + oidc_connect_userinfo_id_path?: string; + oidc_connect_userinfo_username_path?: string; + enable_model_fallback?: boolean; + fallback_model_anthropic?: string; + fallback_model_openai?: string; + fallback_model_gemini?: string; + fallback_model_antigravity?: string; + enable_identity_patch?: boolean; + identity_patch_prompt?: string; + ops_monitoring_enabled?: boolean; + ops_realtime_monitoring_enabled?: boolean; + ops_query_mode_default?: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds?: number; + min_claude_code_version?: string; + max_claude_code_version?: string; + allow_ungrouped_key_scheduling?: boolean; + enable_fingerprint_unification?: boolean; + enable_metadata_passthrough?: boolean; + enable_cch_signing?: boolean; // Payment 
configuration - payment_enabled?: boolean - payment_min_amount?: number - payment_max_amount?: number - payment_daily_limit?: number - payment_order_timeout_minutes?: number - payment_max_pending_orders?: number - payment_enabled_types?: string[] - payment_balance_disabled?: boolean - payment_balance_recharge_multiplier?: number - payment_recharge_fee_rate?: number - payment_load_balance_strategy?: string - payment_product_name_prefix?: string - payment_product_name_suffix?: string - payment_help_image_url?: string - payment_help_text?: string - payment_cancel_rate_limit_enabled?: boolean - payment_cancel_rate_limit_max?: number - payment_cancel_rate_limit_window?: number - payment_cancel_rate_limit_unit?: string - payment_cancel_rate_limit_window_mode?: string + payment_enabled?: boolean; + payment_min_amount?: number; + payment_max_amount?: number; + payment_daily_limit?: number; + payment_order_timeout_minutes?: number; + payment_max_pending_orders?: number; + payment_enabled_types?: string[]; + payment_balance_disabled?: boolean; + payment_balance_recharge_multiplier?: number; + payment_recharge_fee_rate?: number; + payment_load_balance_strategy?: string; + payment_product_name_prefix?: string; + payment_product_name_suffix?: string; + payment_help_image_url?: string; + payment_help_text?: string; + payment_cancel_rate_limit_enabled?: boolean; + payment_cancel_rate_limit_max?: number; + payment_cancel_rate_limit_window?: number; + payment_cancel_rate_limit_unit?: string; + payment_cancel_rate_limit_window_mode?: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled?: boolean - balance_low_notify_threshold?: number - balance_low_notify_recharge_url?: string - account_quota_notify_enabled?: boolean - 
account_quota_notify_emails?: NotifyEmailEntry[] + balance_low_notify_enabled?: boolean; + balance_low_notify_threshold?: number; + balance_low_notify_recharge_url?: string; + account_quota_notify_enabled?: boolean; + account_quota_notify_emails?: NotifyEmailEntry[]; + + // Channel Monitor feature switch + channel_monitor_enabled?: boolean; + channel_monitor_default_interval_seconds?: number; + + // Available Channels feature switch + available_channels_enabled?: boolean; } /** @@ -258,8 +641,8 @@ export interface UpdateSettingsRequest { * @returns System settings */ export async function getSettings(): Promise { - const { data } = await apiClient.get('/admin/settings') - return data + const { data } = await apiClient.get("/admin/settings"); + return data; } /** @@ -267,20 +650,25 @@ export async function getSettings(): Promise { * @param settings - Partial settings to update * @returns Updated settings */ -export async function updateSettings(settings: UpdateSettingsRequest): Promise { - const { data } = await apiClient.put('/admin/settings', settings) - return data +export async function updateSettings( + settings: UpdateSettingsRequest, +): Promise { + const { data } = await apiClient.put( + "/admin/settings", + settings, + ); + return data; } /** * Test SMTP connection request */ export interface TestSmtpRequest { - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_use_tls: boolean; } /** @@ -288,23 +676,28 @@ export interface TestSmtpRequest { * @param config - SMTP configuration to test * @returns Test result message */ -export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> { - const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config) - return data +export async function testSmtpConnection( + config: TestSmtpRequest, +): 
Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>( + "/admin/settings/test-smtp", + config, + ); + return data; } /** * Send test email request */ export interface SendTestEmailRequest { - email: string - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + email: string; + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; } /** @@ -312,20 +705,22 @@ export interface SendTestEmailRequest { * @param request - Email address and SMTP config * @returns Test result message */ -export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> { +export async function sendTestEmail( + request: SendTestEmailRequest, +): Promise<{ message: string }> { const { data } = await apiClient.post<{ message: string }>( - '/admin/settings/send-test-email', - request - ) - return data + "/admin/settings/send-test-email", + request, + ); + return data; } /** * Admin API Key status response */ export interface AdminApiKeyStatus { - exists: boolean - masked_key: string + exists: boolean; + masked_key: string; } /** @@ -333,8 +728,10 @@ export interface AdminApiKeyStatus { * @returns Status indicating if key exists and masked version */ export async function getAdminApiKey(): Promise { - const { data } = await apiClient.get('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.get( + "/admin/settings/admin-api-key", + ); + return data; } /** @@ -342,8 +739,10 @@ export async function getAdminApiKey(): Promise { * @returns The new full API key (only shown once) */ export async function regenerateAdminApiKey(): Promise<{ key: string }> { - const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate') - return data + const { data } 
= await apiClient.post<{ key: string }>( + "/admin/settings/admin-api-key/regenerate", + ); + return data; } /** @@ -351,8 +750,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> { * @returns Success message */ export async function deleteAdminApiKey(): Promise<{ message: string }> { - const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.delete<{ message: string }>( + "/admin/settings/admin-api-key", + ); + return data; } // ==================== Overload Cooldown Settings ==================== @@ -361,23 +762,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { * Overload cooldown settings interface (529 handling) */ export interface OverloadCooldownSettings { - enabled: boolean - cooldown_minutes: number + enabled: boolean; + cooldown_minutes: number; } export async function getOverloadCooldownSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/overload-cooldown') - return data + const { data } = await apiClient.get( + "/admin/settings/overload-cooldown", + ); + return data; } export async function updateOverloadCooldownSettings( - settings: OverloadCooldownSettings + settings: OverloadCooldownSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/overload-cooldown', - settings - ) - return data + "/admin/settings/overload-cooldown", + settings, + ); + return data; } // ==================== Stream Timeout Settings ==================== @@ -386,11 +789,11 @@ export async function updateOverloadCooldownSettings( * Stream timeout settings interface */ export interface StreamTimeoutSettings { - enabled: boolean - action: 'temp_unsched' | 'error' | 'none' - temp_unsched_minutes: number - threshold_count: number - threshold_window_minutes: number + enabled: boolean; + action: "temp_unsched" | "error" | "none"; + temp_unsched_minutes: number; + threshold_count: number; + 
threshold_window_minutes: number; } /** @@ -398,8 +801,10 @@ export interface StreamTimeoutSettings { * @returns Stream timeout settings */ export async function getStreamTimeoutSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/stream-timeout') - return data + const { data } = await apiClient.get( + "/admin/settings/stream-timeout", + ); + return data; } /** @@ -408,13 +813,13 @@ export async function getStreamTimeoutSettings(): Promise * @returns Updated settings */ export async function updateStreamTimeoutSettings( - settings: StreamTimeoutSettings + settings: StreamTimeoutSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/stream-timeout', - settings - ) - return data + "/admin/settings/stream-timeout", + settings, + ); + return data; } // ==================== Rectifier Settings ==================== @@ -423,11 +828,11 @@ export async function updateStreamTimeoutSettings( * Rectifier settings interface */ export interface RectifierSettings { - enabled: boolean - thinking_signature_enabled: boolean - thinking_budget_enabled: boolean - apikey_signature_enabled: boolean - apikey_signature_patterns: string[] + enabled: boolean; + thinking_signature_enabled: boolean; + thinking_budget_enabled: boolean; + apikey_signature_enabled: boolean; + apikey_signature_patterns: string[]; } /** @@ -435,8 +840,10 @@ export interface RectifierSettings { * @returns Rectifier settings */ export async function getRectifierSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/rectifier') - return data + const { data } = await apiClient.get( + "/admin/settings/rectifier", + ); + return data; } /** @@ -445,13 +852,13 @@ export async function getRectifierSettings(): Promise { * @returns Updated settings */ export async function updateRectifierSettings( - settings: RectifierSettings + settings: RectifierSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/rectifier', - settings - ) - 
return data + "/admin/settings/rectifier", + settings, + ); + return data; } // ==================== Beta Policy Settings ==================== @@ -460,20 +867,20 @@ export async function updateRectifierSettings( * Beta policy rule interface */ export interface BetaPolicyRule { - beta_token: string - action: 'pass' | 'filter' | 'block' - scope: 'all' | 'oauth' | 'apikey' | 'bedrock' - error_message?: string - model_whitelist?: string[] - fallback_action?: 'pass' | 'filter' | 'block' - fallback_error_message?: string + beta_token: string; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; } /** * Beta policy settings interface */ export interface BetaPolicySettings { - rules: BetaPolicyRule[] + rules: BetaPolicyRule[]; } /** @@ -481,8 +888,10 @@ export interface BetaPolicySettings { * @returns Beta policy settings */ export async function getBetaPolicySettings(): Promise { - const { data } = await apiClient.get('/admin/settings/beta-policy') - return data + const { data } = await apiClient.get( + "/admin/settings/beta-policy", + ); + return data; } /** @@ -491,70 +900,73 @@ export async function getBetaPolicySettings(): Promise { * @returns Updated settings */ export async function updateBetaPolicySettings( - settings: BetaPolicySettings + settings: BetaPolicySettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/beta-policy', - settings - ) - return data + "/admin/settings/beta-policy", + settings, + ); + return data; } // --- Web Search Emulation Config --- export interface WebSearchProviderConfig { - type: 'brave' | 'tavily' - api_key: string - api_key_configured: boolean - quota_limit: number | null - subscribed_at: number | null - quota_used?: number - proxy_id: number | null - expires_at: number | null + type: "brave" | "tavily"; + api_key: string; + 
api_key_configured: boolean; + quota_limit: number | null; + subscribed_at: number | null; + quota_used?: number; + proxy_id: number | null; + expires_at: number | null; } export interface WebSearchEmulationConfig { - enabled: boolean - providers: WebSearchProviderConfig[] + enabled: boolean; + providers: WebSearchProviderConfig[]; } export interface WebSearchTestResult { - provider: string - results: { url: string; title: string; snippet: string; page_age?: string }[] - query: string + provider: string; + results: { url: string; title: string; snippet: string; page_age?: string }[]; + query: string; } export async function getWebSearchEmulationConfig(): Promise { const { data } = await apiClient.get( - '/admin/settings/web-search-emulation' - ) - return data + "/admin/settings/web-search-emulation", + ); + return data; } export async function updateWebSearchEmulationConfig( - config: WebSearchEmulationConfig + config: WebSearchEmulationConfig, ): Promise { const { data } = await apiClient.put( - '/admin/settings/web-search-emulation', - config - ) - return data + "/admin/settings/web-search-emulation", + config, + ); + return data; } export async function testWebSearchEmulation( - query: string + query: string, ): Promise { const { data } = await apiClient.post( - '/admin/settings/web-search-emulation/test', - { query } - ) - return data + "/admin/settings/web-search-emulation/test", + { query }, + ); + return data; } -export async function resetWebSearchUsage( - payload: { provider_type: string } -): Promise { - await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +export async function resetWebSearchUsage(payload: { + provider_type: string; +}): Promise { + await apiClient.post( + "/admin/settings/web-search-emulation/reset-usage", + payload, + ); } export const settingsAPI = { @@ -576,7 +988,7 @@ export const settingsAPI = { getWebSearchEmulationConfig, updateWebSearchEmulationConfig, testWebSearchEmulation, - resetWebSearchUsage -} 
+ resetWebSearchUsage, +}; -export default settingsAPI +export default settingsAPI; diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 39cb1dfa69217d7492e592d955cc5d7f2eb2aa63..3c75a6c4f5a7f4febe14be861ddd7e3a60e8eaff 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -6,6 +6,44 @@ import { apiClient } from '../client' import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types' +export interface AdminBindAuthIdentityChannelRequest { + channel: string + channel_app_id: string + channel_subject: string + metadata?: Record | null +} + +export interface AdminBindAuthIdentityRequest { + provider_type: string + provider_key: string + provider_subject: string + issuer?: string | null + metadata?: Record | null + channel?: AdminBindAuthIdentityChannelRequest +} + +export interface AdminBoundAuthIdentityChannel { + channel: string + channel_app_id: string + channel_subject: string + metadata: Record | null + created_at: string + updated_at: string +} + +export interface AdminBoundAuthIdentity { + user_id: number + provider_type: string + provider_key: string + provider_subject: string + verified_at?: string | null + issuer?: string | null + metadata: Record | null + created_at: string + updated_at: string + channel?: AdminBoundAuthIdentityChannel | null +} + /** * List all users with pagination * @param page - Page number (default: 1) @@ -248,6 +286,17 @@ export async function replaceGroup( return data } +export async function bindUserAuthIdentity( + userId: number, + input: AdminBindAuthIdentityRequest +): Promise { + const { data } = await apiClient.post( + `/admin/users/${userId}/auth-identities`, + input + ) + return data +} + export const usersAPI = { list, getById, @@ -260,7 +309,8 @@ export const usersAPI = { getUserApiKeys, getUserUsageStats, getUserBalanceHistory, - replaceGroup + replaceGroup, + bindUserAuthIdentity } export default usersAPI diff --git 
a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 837c4f4cf7068fe1ce394f49ad83131d346cff01..f49f3a1f843bcf9e5a983b8c4830d23a0178ab51 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,6 +186,108 @@ export interface RefreshTokenResponse { token_type: string } +export interface OAuthTokenResponse { + access_token: string + refresh_token?: string + expires_in?: number + token_type?: string +} + +export interface PendingOAuthBindLoginResponse extends Partial { + auth_result?: string + redirect?: string + error?: string + requires_2fa?: boolean + temp_token?: string + user_email_masked?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +} + +export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse + +export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse { + auth_result?: string +} + +export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse { + auth_result?: string + provider?: string + redirect?: string +} + +export type OAuthCompletionKind = 'login' | 'bind' + +export interface OAuthAdoptionDecision { + adoptDisplayName?: boolean + adoptAvatar?: boolean +} + +function serializeOAuthAdoptionDecision( + decision?: OAuthAdoptionDecision +): Record { + const payload: Record = {} + + if (typeof decision?.adoptDisplayName === 'boolean') { + payload.adopt_display_name = decision.adoptDisplayName + } + if (typeof decision?.adoptAvatar === 'boolean') { + payload.adopt_avatar = decision.adoptAvatar + } + + return payload +} + +export function isOAuthLoginCompletion( + completion: Partial +): completion is OAuthTokenResponse { + return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0 +} + +export function getOAuthCompletionKind( + completion: Partial +): OAuthCompletionKind { + return isOAuthLoginCompletion(completion) ? 
'login' : 'bind' +} + +export function getPendingOAuthBindLoginKind( + completion: PendingOAuthBindLoginResponse +): OAuthCompletionKind { + return getOAuthCompletionKind(completion) +} + +export function isPendingOAuthCreateAccountRequired( + completion: Pick +): boolean { + return completion.error === 'invitation_required' +} + +export function hasPendingOAuthSuggestedProfile( + completion: Pick< + PendingOAuthBindLoginResponse, + 'suggested_display_name' | 'suggested_avatar_url' + > +): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +export function persistOAuthTokenContext(tokens: Partial): void { + if (tokens.refresh_token) { + setRefreshToken(tokens.refresh_token) + } + if (tokens.expires_in) { + setTokenExpiresAt(tokens.expires_in) + } +} + +export async function prepareOAuthBindAccessTokenCookie(): Promise { + if (!getAuthToken()) { + return + } + await apiClient.post('/auth/oauth/bind-token') +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -234,6 +336,116 @@ export async function getPublicSettings(): Promise { return data } +export type WeChatOAuthMode = 'open' | 'mp' +export type WeChatOAuthUnavailableReason = + | 'not_configured' + | 'capability_unknown' + | 'external_browser_required' + | 'wechat_browser_required' + | 'native_app_required' + +export interface ResolvedWeChatOAuthStart { + mode: WeChatOAuthMode | null + openEnabled: boolean + mpEnabled: boolean + mobileEnabled: boolean + isWeChatBrowser: boolean + unavailableReason: WeChatOAuthUnavailableReason | null +} + +export type WeChatOAuthPublicSettings = { + wechat_oauth_enabled?: boolean + wechat_oauth_open_enabled?: boolean + wechat_oauth_mp_enabled?: boolean + wechat_oauth_mobile_enabled?: boolean +} + +export function isWeChatWebOAuthEnabled( + settings: WeChatOAuthPublicSettings | null | undefined, +): boolean { + const legacyEnabled = settings?.wechat_oauth_enabled ?? 
false + const hasExplicitCapabilities = + typeof settings?.wechat_oauth_open_enabled === 'boolean' || + typeof settings?.wechat_oauth_mp_enabled === 'boolean' + + if (!hasExplicitCapabilities) { + return legacyEnabled + } + + return settings?.wechat_oauth_open_enabled === true || settings?.wechat_oauth_mp_enabled === true +} + +export function hasExplicitWeChatOAuthCapabilities( + settings: WeChatOAuthPublicSettings | null | undefined, +): settings is WeChatOAuthPublicSettings & { + wechat_oauth_open_enabled: boolean + wechat_oauth_mp_enabled: boolean +} { + return typeof settings?.wechat_oauth_open_enabled === 'boolean' + && typeof settings?.wechat_oauth_mp_enabled === 'boolean' +} + +export function resolveWeChatOAuthStart( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? '').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + const legacyEnabled = settings?.wechat_oauth_enabled ?? false + const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean' + ? settings.wechat_oauth_open_enabled + : legacyEnabled + const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean' + ? settings.wechat_oauth_mp_enabled + : legacyEnabled + const mobileEnabled = typeof settings?.wechat_oauth_mobile_enabled === 'boolean' + ? 
settings.wechat_oauth_mobile_enabled + : false + + if (isWeChatBrowser) { + if (mpEnabled) { + return { mode: 'mp', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null } + } + if (openEnabled) { + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } + } + + if (openEnabled) { + return { mode: 'open', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null } + } + if (mpEnabled) { + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } +} + +export function resolveWeChatOAuthStartStrict( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string, +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? 
'').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + + if (!hasExplicitWeChatOAuthCapabilities(settings)) { + return { + mode: null, + openEnabled: false, + mpEnabled: false, + mobileEnabled: false, + isWeChatBrowser, + unavailableReason: 'capability_unknown', + } + } + + return resolveWeChatOAuthStart(settings, normalizedUserAgent) +} + /** * Send verification code to email * @param request - Email and optional Turnstile token @@ -246,6 +458,16 @@ export async function sendVerifyCode( return data } +export async function sendPendingOAuthVerifyCode( + request: SendVerifyCodeRequest +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/send-verify-code', + request + ) + return data +} + /** * Validate promo code response */ @@ -337,48 +559,87 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/linuxdo/complete-registration', { - pending_oauth_token: pendingOAuthToken, - invitation_code: invitationCode - }) - return data + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingLinuxDoOAuthAccount(invitationCode, decision) } /** * Complete OIDC OAuth registration by supplying an invitation code - * @param pendingOAuthToken - Short-lived JWT from the OAuth callback * @param invitationCode - Invitation code entered by the user * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - pendingOAuthToken: string, - invitationCode: string -): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { - const { data } = await apiClient.post<{ - access_token: string - refresh_token: string - expires_in: number - token_type: string - }>('/auth/oauth/oidc/complete-registration', { - pending_oauth_token: pendingOAuthToken, - 
invitation_code: invitationCode - }) + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOIDCOAuthAccount(invitationCode, decision) +} + +export async function completeWeChatOAuthRegistration( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingWeChatOAuthAccount(invitationCode, decision) +} + +async function createPendingOAuthAccount( + provider: 'linuxdo' | 'oidc' | 'wechat', + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + `/auth/oauth/${provider}/complete-registration`, + { + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) + } + ) return data } +export async function createPendingLinuxDoOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('linuxdo', invitationCode, decision) +} + +export async function createPendingOIDCOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('oidc', invitationCode, decision) +} + +export async function createPendingWeChatOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('wechat', invitationCode, decision) +} + +export async function completePendingOAuthBindLogin( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) + return data +} + +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + return completePendingOAuthBindLogin(decision) +} + export const authAPI = { login, login2FA, @@ -396,14 +657,24 @@ export const authAPI = { clearAuthToken, getPublicSettings, sendVerifyCode, + sendPendingOAuthVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, 
resetPassword, refreshToken, revokeAllSessions, + getPendingOAuthBindLoginKind, + isPendingOAuthCreateAccountRequired, + hasPendingOAuthSuggestedProfile, + completePendingOAuthBindLogin, + createPendingLinuxDoOAuthAccount, + createPendingOIDCOAuthAccount, + createPendingWeChatOAuthAccount, + exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, - completeOIDCOAuthRegistration + completeOIDCOAuthRegistration, + completeWeChatOAuthRegistration } export default authAPI diff --git a/frontend/src/api/channelMonitor.ts b/frontend/src/api/channelMonitor.ts new file mode 100644 index 0000000000000000000000000000000000000000..38dd0c99a8babc56143a719c7ed876efa997519a --- /dev/null +++ b/frontend/src/api/channelMonitor.ts @@ -0,0 +1,83 @@ +/** + * User-facing Channel Monitor API endpoints + * Read-only views for end users to inspect channel availability/status. + */ + +import { apiClient } from './client' +import type { Provider, MonitorStatus } from './admin/channelMonitor' + +export type { Provider, MonitorStatus } from './admin/channelMonitor' + +export interface UserMonitorExtraModel { + model: string + status: MonitorStatus + latency_ms: number | null +} + +export interface MonitorTimelinePoint { + status: MonitorStatus + latency_ms: number | null + ping_latency_ms: number | null + checked_at: string +} + +export interface UserMonitorView { + id: number + name: string + provider: Provider + group_name: string + primary_model: string + primary_status: MonitorStatus + primary_latency_ms: number | null + primary_ping_latency_ms: number | null + availability_7d: number + extra_models: UserMonitorExtraModel[] + timeline: MonitorTimelinePoint[] +} + +export interface UserMonitorListResponse { + items: UserMonitorView[] +} + +export interface UserMonitorModelDetail { + model: string + latest_status: MonitorStatus + latest_latency_ms: number | null + availability_7d: number + availability_15d: number + availability_30d: number + avg_latency_7d_ms: number | null +} + 
+export interface UserMonitorDetail { + id: number + name: string + provider: Provider + group_name: string + models: UserMonitorModelDetail[] +} + +/** + * List all monitor views available to the current user. + */ +export async function list(options?: { signal?: AbortSignal }): Promise { + const { data } = await apiClient.get('/channel-monitors', { + signal: options?.signal, + }) + return data +} + +/** + * Get detailed status (multi-window availability + latency) for a single monitor. + */ +export async function status(id: number): Promise { + const { data } = await apiClient.get(`/channel-monitors/${id}/status`) + return data +} + +export const channelMonitorUserAPI = { + list, + status, +} + +export default channelMonitorUserAPI diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts new file mode 100644 index 0000000000000000000000000000000000000000..8962af2c4d84355957820dfc5e6056dbc0b74275 --- /dev/null +++ b/frontend/src/api/channels.ts @@ -0,0 +1,76 @@ +/** + * User Channels API endpoints (non-admin) + * 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。 + */ + +import { apiClient } from './client' +import type { BillingMode } from '@/constants/channel' + +export interface UserAvailableGroup { + id: number + name: string + platform: string + /** 'standard' | 'subscription' — 订阅分组视觉加深,和 API 密钥页保持一致。 */ + subscription_type: string + /** 分组默认倍率。用户专属倍率(若有)通过 /groups/rates 获取后在前端 join。 */ + rate_multiplier: number + /** true = 专属分组(小范围授权);false = 公开分组。 */ + is_exclusive: boolean +} + +export interface UserPricingInterval { + min_tokens: number + max_tokens: number | null + tier_label?: string + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + per_request_price: number | null +} + +export interface UserSupportedModelPricing { + billing_mode: BillingMode + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | 
null + image_output_price: number | null + per_request_price: number | null + intervals: UserPricingInterval[] +} + +export interface UserSupportedModel { + name: string + platform: string + pricing: UserSupportedModelPricing | null +} + +/** + * 渠道下单个平台的子视图:用户可访问的分组 + 该平台支持的模型。 + * 后端把一个渠道按平台聚合成 sections,前端可以把渠道名作为 row-group + * 一次渲染,后面按 sections 顺序用 rowspan 铺开。 + */ +export interface UserChannelPlatformSection { + platform: string + groups: UserAvailableGroup[] + supported_models: UserSupportedModel[] +} + +export interface UserAvailableChannel { + name: string + description: string + platforms: UserChannelPlatformSection[] +} + +/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回平数组)。 */ +export async function getAvailable(options?: { signal?: AbortSignal }): Promise { + const { data } = await apiClient.get('/channels/available', { + signal: options?.signal + }) + return data +} + +export const userChannelsAPI = { getAvailable } + +export default userChannelsAPI diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 8a586902741016ff3f94ae727c9c98fc35fdce8c..54ea4520097974e24f29cf7f7226dcb50e25f4f7 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -13,6 +13,7 @@ const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || '/api/v1' export const apiClient: AxiosInstance = axios.create({ baseURL: API_BASE_URL, + withCredentials: true, timeout: 30000, headers: { 'Content-Type': 'application/json' diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 3b38eaa59762050b73bb01b3289771fbdb0c650c..f0b82b15f21acbf3d73fb234d1cdbf718c0b13ab 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -16,8 +16,10 @@ export { userAPI } from './user' export { redeemAPI, type RedeemHistoryItem } from './redeem' export { paymentAPI } from './payment' export { userGroupsAPI } from './groups' +export { userChannelsAPI } from './channels' export { totpAPI } from './totp' export { default as announcementsAPI } 
from './announcements' +export { channelMonitorUserAPI } from './channelMonitor' // Admin APIs export { adminAPI } from './admin' diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 5cedb107ec77e14f92c86b2f605e2353fa9862bc..92b0ec90c6a2102f95a451c1d5225a341939b924 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -67,11 +67,16 @@ export const paymentAPI = { return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo }) }, - /** Verify order payment status without auth (public endpoint for result page) */ + /** Legacy-compatible public order lookup by out_trade_no */ verifyOrderPublic(outTradeNo: string) { return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo }) }, + /** Resolve an order from a signed resume token without auth */ + resolveOrderPublicByResumeToken(resumeToken: string) { + return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken }) + }, + /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index cd6482708f3faf3d269dba0dfbc3f7899edf2097..fd3cedb9a1460d8d1269f50205f78f6572661991 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -4,7 +4,12 @@ */ import { apiClient } from './client' -import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types' +import { + resolveWeChatOAuthStartStrict, + prepareOAuthBindAccessTokenCookie, + type WeChatOAuthPublicSettings, +} from './auth' +import type { User, ChangePasswordRequest, NotifyEmailEntry, UserAuthProvider } from '@/types' /** * Get current user profile @@ -22,6 +27,7 @@ export async function getProfile(): Promise { */ export async function updateProfile(profile: { username?: string + avatar_url?: string | null balance_notify_enabled?: boolean balance_notify_threshold?: number 
| null balance_notify_extra_emails?: NotifyEmailEntry[] @@ -83,6 +89,85 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi return data } +export async function sendEmailBindingCode(email: string): Promise { + await apiClient.post('/user/account-bindings/email/send-code', { email }) +} + +export async function bindEmailIdentity(payload: { + email: string + verify_code: string + password: string +}): Promise { + const { data } = await apiClient.post('/user/account-bindings/email', payload) + return data +} + +export async function unbindAuthIdentity(provider: BindableOAuthProvider): Promise { + const { data } = await apiClient.delete(`/user/account-bindings/${provider}`) + return data +} + +export type BindableOAuthProvider = Exclude + +interface BuildOAuthBindingStartURLOptions { + redirectTo?: string + wechatOAuthSettings?: WeChatOAuthPublicSettings | null +} + +export function resolveWeChatOAuthMode(): 'open' | 'mp' { + if (typeof navigator === 'undefined') { + return 'open' + } + return /MicroMessenger/i.test(navigator.userAgent) ? 
'mp' : 'open' +} + +function resolveWeChatOAuthBindingMode( + settings?: WeChatOAuthPublicSettings | null +): 'open' | 'mp' | null { + if (settings) { + return resolveWeChatOAuthStartStrict(settings).mode + } + return resolveWeChatOAuthMode() +} + +export function buildOAuthBindingStartURL( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): string | null { + const redirectTo = options.redirectTo?.trim() || '/profile' + const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1' + const normalized = apiBase.replace(/\/$/, '') + const params = new URLSearchParams({ + redirect: redirectTo, + intent: 'bind_current_user' + }) + + if (provider === 'wechat') { + const mode = resolveWeChatOAuthBindingMode(options.wechatOAuthSettings) + if (!mode) { + return null + } + params.set('mode', mode) + } + + return `${normalized}/auth/oauth/${provider}/bind/start?${params.toString()}` +} + +export async function startOAuthBinding( + provider: BindableOAuthProvider, + options: BuildOAuthBindingStartURLOptions = {} +): Promise { + if (typeof window === 'undefined') { + return + } + const startURL = buildOAuthBindingStartURL(provider, options) + if (!startURL) { + return + } + await prepareOAuthBindAccessTokenCookie() + window.location.href = startURL +} + export const userAPI = { getProfile, updateProfile, @@ -90,7 +175,12 @@ export const userAPI = { sendNotifyEmailCode, verifyNotifyEmail, removeNotifyEmail, - toggleNotifyEmail + toggleNotifyEmail, + sendEmailBindingCode, + bindEmailIdentity, + unbindAuthIdentity, + buildOAuthBindingStartURL, + startOAuthBinding } export default userAPI diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index fc2f7d0c88eaf662b0edd7c9f5ab037b86af18ec..dd38a49f9d23a8a873e64c884fabdb0b3869ea6b 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ 
b/frontend/src/components/account/AccountStatusIndicator.vue @@ -284,6 +284,16 @@ const hasError = computed(() => { return props.account.status === 'error' }) +const isQuotaExceeded = computed(() => { + const exceeded = (used?: number | null, limit?: number | null) => + typeof limit === 'number' && limit > 0 && typeof used === 'number' && used >= limit + return ( + exceeded(props.account.quota_used, props.account.quota_limit) || + exceeded(props.account.quota_daily_used, props.account.quota_daily_limit) || + exceeded(props.account.quota_weekly_used, props.account.quota_weekly_limit) + ) +}) + // Computed: countdown text for rate limit (429) const rateLimitCountdown = computed(() => { return formatCountdown(props.account.rate_limit_reset_at) @@ -307,19 +317,16 @@ const statusClass = computed(() => { if (isTempUnschedulable.value) { return 'badge-warning' } + if (props.account.status !== 'active') { + return props.account.status === 'error' ? 'badge-danger' : 'badge-gray' + } + if (isQuotaExceeded.value) { + return 'badge-warning' + } if (!props.account.schedulable) { return 'badge-gray' } - switch (props.account.status) { - case 'active': - return 'badge-success' - case 'inactive': - return 'badge-gray' - case 'error': - return 'badge-danger' - default: - return 'badge-gray' - } + return 'badge-success' }) // Computed: status text @@ -330,6 +337,12 @@ const statusText = computed(() => { if (isTempUnschedulable.value) { return t('admin.accounts.status.tempUnschedulable') } + if (props.account.status !== 'active') { + return t(`admin.accounts.status.${props.account.status}`) + } + if (isQuotaExceeded.value) { + return t('admin.accounts.status.quotaExceeded') + } if (!props.account.schedulable) { return t('admin.accounts.status.paused') } diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 67409a7cd30385ce3629bf1765d1ca0cb2611d57..2e3db61bfdc376c78756756fe175cfec1cf3eaf2 100644 --- 
a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -55,12 +55,12 @@ /> -
+
+

+ {{ t("admin.settings.site.homeContentHint") }} +

+ +

+ {{ t("admin.settings.site.homeContentIframeWarning") }} +

-
- {{ t('admin.settings.webSearchEmulation.noProviders') }} + +
+
+ +

+ {{ t("admin.settings.site.hideCcsImportButtonHint") }} +

+
+
+
+
-
- -
-
- +
+
+

+ {{ t("admin.settings.customMenu.title") }} +

+

+ {{ t("admin.settings.customMenu.description") }} +

+
+
+ +
+
+ + {{ + t("admin.settings.customMenu.itemLabel", { n: index + 1 }) + }} + +
+ + - -
-
+ +
- -
-
- - -

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

-
-
- - -

{{ t('admin.settings.webSearchEmulation.subscribedAtHint') }}

-
+ +
+ +
- -
- {{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
-
-
-
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} - + {{ t("admin.settings.customMenu.url") }} + +
- -
-
- - -
- + {{ t("admin.settings.customMenu.iconSvg") }} + +
-
-
-
- -
-
-

- {{ t('admin.settings.webSearchEmulation.testResultTitle') }} -

-
- + -
- -
-

- {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} -

-
- {{ t('admin.settings.webSearchEmulation.testNoResults') }} -
-
- {{ r.title }} -

{{ r.snippet }}

-
-
-
-
+ -
+ +
- -
-

- {{ t('admin.settings.site.title') }} + {{ t('admin.settings.features.channelMonitor.title') }}

- {{ t('admin.settings.site.description') }} + {{ t('admin.settings.features.channelMonitor.description') }} +

+

+ + {{ t('admin.settings.features.channelMonitor.configureLink') }} + +

-
- -
-
-

- {{ t('admin.settings.site.backendMode') }} -

-

- {{ t('admin.settings.site.backendModeDescription') }} -

-
- -
- -
-
- - -

- {{ t('admin.settings.site.siteNameHint') }} -

-
+
+
-
+
- -
-
+
- -
-

- {{ t('admin.settings.site.tablePreferencesTitle') }} -

-

- {{ t('admin.settings.site.tablePreferencesDescription') }} -

-
-
- - -

- {{ t('admin.settings.site.tableDefaultPageSizeHint') }} -

-
-
- - -

- {{ t('admin.settings.site.tablePageSizeOptionsHint') }} -

-
+
+
+

+ {{ t('admin.settings.features.availableChannels.title') }} +

+

+ {{ t('admin.settings.features.availableChannels.description') }} +

+

+ + {{ t('admin.settings.features.availableChannels.configureLink') }} + + +

+
+
+
+
+ +

+ {{ t('admin.settings.features.availableChannels.enabledHint') }} +

+
+
+
- -
- -

- {{ t('admin.settings.site.customEndpoints.description') }} -

+
-
-
+ +
+ +
+
+

+ {{ t("admin.settings.payment.title") }} +

+

+ {{ t("admin.settings.payment.description") }} + -

+
+ +
+
+ +

+ {{ t("admin.settings.payment.enabledHint") }} +

+
+ +
+
+
- -
-
- -

- {{ t('admin.settings.site.hideCcsImportButtonHint') }} -

+ + +
+ +
+ +
+
+
+ +
+

+ {{ t("admin.settings.emailTabDisabledTitle") }} +

+

+ {{ t("admin.settings.emailTabDisabledHint") }} +

+
-
-
- - -
-
-

- {{ t('admin.settings.customMenu.title') }} -

-

- {{ t('admin.settings.customMenu.description') }} -

-
-
- + +
-
- - {{ t('admin.settings.customMenu.itemLabel', { n: index + 1 }) }} - -
- - - - - - -
+
+

+ {{ t("admin.settings.smtp.title") }} +

+

+ {{ t("admin.settings.smtp.description") }} +

- -
- + +
+
+
-
- -
- - -
- - -
-
- - -
-
-
- - - -
-
- -
- - - -
- - -
-
-

{{ t('admin.settings.payment.title') }}

-

- {{ t('admin.settings.payment.description') }} - - - {{ t('admin.settings.payment.configGuide') }} - -

-
-
- -
-
- -

{{ t('admin.settings.payment.enabledHint') }}

-
- -
- -
-
- - - - -
- -
- -
-
-
- -
-

- {{ t('admin.settings.emailTabDisabledTitle') }} -

-

- {{ t('admin.settings.emailTabDisabledHint') }} -

+
-
- -
-
-
+ +
+

- {{ t('admin.settings.smtp.title') }} + {{ t("admin.settings.testEmail.title") }}

- {{ t('admin.settings.smtp.description') }} + {{ t("admin.settings.testEmail.description") }}

- -
-
-
-
- - -
-
- - -
-
- - -
-
- - +
+
+ + +
+
-
- - -
-
- - + class="btn btn-secondary" + > + + + + + {{ + sendingTestEmail + ? t("admin.settings.testEmail.sending") + : t("admin.settings.testEmail.sendTestEmail") + }} +
- - +
+ +
-
- -

- {{ t('admin.settings.smtp.useTlsHint') }} +

+ {{ t("admin.settings.balanceNotify.title") }} +

+

+ {{ t("admin.settings.balanceNotify.description") }} +

+
+
+
+ + +
+
+ +
+ $ + +
+

+ {{ t("admin.settings.balanceNotify.thresholdHint") }}

- -
- -
-
- - -
-
-

- {{ t('admin.settings.testEmail.title') }} -

-

- {{ t('admin.settings.testEmail.description') }} -

-
-
-
-
- +
+ +

+ {{ t("admin.settings.balanceNotify.rechargeUrlHint") }} +

- -
-
-
- -
-
-

- {{ t('admin.settings.balanceNotify.title') }} -

-

- {{ t('admin.settings.balanceNotify.description') }} -

-
-
-
- - -
-
- -
- $ - -
-

{{ t('admin.settings.balanceNotify.thresholdHint') }}

-
-
- - -

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

-
- -
-
-

- {{ t('admin.settings.quotaNotify.title') }} -

-

- {{ t('admin.settings.quotaNotify.description') }} -

-
-
-
- - + +
+
+

+ {{ t("admin.settings.quotaNotify.title") }} +

+

+ {{ t("admin.settings.quotaNotify.description") }} +

-
- -
-
- - - +
+
- +

+ {{ t("admin.settings.quotaNotify.emailsHint") }} +

-

{{ t('admin.settings.quotaNotify.emailsHint') }}

-
+
@@ -2784,8 +4712,17 @@
-
@@ -2818,149 +4759,208 @@ @close="showProviderDialog = false" @save="handleSaveProvider" /> - +
- diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index f61acc1ecdc192bffdd7d97144adfd3366508cad..610fa3ba1ab2c9bd21ebafc7eda14a33b5cad404 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,12 +11,17 @@

-
+
+
-

- {{ errors.email }} -

@@ -91,10 +93,7 @@
-

- {{ errors.password }} -

- + -

- {{ errors.turnstile }} -

- - -
-
-
- -
-

- {{ errorMessage }} -

-
-
-
-
-
-

- {{ t('auth.oidc.invitationRequired', { providerName }) }} -

-
- +
+
+
+
+

+ {{ t('auth.oauthFlow.profileDetailsTitle', { providerName }) }} +

+

+ {{ t('auth.oauthFlow.profileDetailsDescription', { providerName }) }} +

+
+ + + + +
- -

- {{ invitationError }} + + - -

-
-
- + + + + + + + + +
@@ -73,15 +243,26 @@ + + diff --git a/frontend/src/views/auth/WechatPaymentCallbackView.vue b/frontend/src/views/auth/WechatPaymentCallbackView.vue new file mode 100644 index 0000000000000000000000000000000000000000..225c84e13a6afa87453f145827186c1e07028666 --- /dev/null +++ b/frontend/src/views/auth/WechatPaymentCallbackView.vue @@ -0,0 +1,150 @@ + + + diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..9f67a994286cc1d85072b5a022054c8c374fafdf --- /dev/null +++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts @@ -0,0 +1,453 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import EmailVerifyView from '@/views/auth/EmailVerifyView.vue' + +const { + pushMock, + showSuccessMock, + showErrorMock, + registerMock, + setTokenMock, + setPendingAuthSessionMock, + clearPendingAuthSessionMock, + getPublicSettingsMock, + sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, + persistOAuthTokenContextMock, + apiClientPostMock, + authStoreState, +} = vi.hoisted(() => ({ + pushMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + registerMock: vi.fn(), + setTokenMock: vi.fn(), + setPendingAuthSessionMock: vi.fn(), + clearPendingAuthSessionMock: vi.fn(), + getPublicSettingsMock: vi.fn(), + sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), + persistOAuthTokenContextMock: vi.fn(), + apiClientPostMock: vi.fn(), + authStoreState: { + pendingAuthSession: null as null | { + token: string + token_field: 'pending_auth_token' | 'pending_oauth_token' + provider: string + redirect?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string + } + }, +})) + +vi.mock('vue-router', () => ({ + useRouter: () => ({ + push: pushMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + 
createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.accountCreatedSuccess') { + return `Account created for ${params?.siteName ?? 'Sub2API'}` + } + return key + }, + locale: { value: 'en' }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + pendingAuthSession: authStoreState.pendingAuthSession, + register: (...args: any[]) => registerMock(...args), + setToken: (...args: any[]) => setTokenMock(...args), + setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args), + clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args), + }), + useAppStore: () => ({ + showSuccess: (...args: any[]) => showSuccessMock(...args), + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), + persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args), + } +}) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPostMock(...args), + }, +})) + +describe('EmailVerifyView', () => { + beforeEach(() => { + pushMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + registerMock.mockReset() + setTokenMock.mockReset() + setPendingAuthSessionMock.mockReset() + clearPendingAuthSessionMock.mockReset() + getPublicSettingsMock.mockReset() + sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() + persistOAuthTokenContextMock.mockReset() + apiClientPostMock.mockReset() + authStoreState.pendingAuthSession = null + sessionStorage.clear() + + getPublicSettingsMock.mockResolvedValue({ + 
turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: [], + }) + sendVerifyCodeMock.mockResolvedValue({ countdown: 60 }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 }) + setTokenMock.mockResolvedValue({}) + }) + + it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-1', + token_field: 'pending_auth_token', + provider: 'wechat', + redirect: '/profile', + } + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-1', + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + }) + + it('skips the registration email suffix whitelist for pending oauth verification', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-2', + token_field: 'pending_auth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-2', + }) + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_oauth_token: undefined, + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ + auth_result: 'pending_session', + provider: 'oidc', + redirect: '/profile/security', + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-1', + token_field: 'pending_auth_token', + provider: 'wechat', + redirect: '/profile', + } + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'oauth-access-token', + refresh_token: 'oauth-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('123456') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'fresh@example.com', + password: 'secret-123', + verify_code: '123456', + }) + expect(persistOAuthTokenContextMock).toHaveBeenCalledWith({ + access_token: 'oauth-access-token', + refresh_token: 'oauth-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }) + expect(setTokenMock).toHaveBeenCalledWith('oauth-access-token') + expect(clearPendingAuthSessionMock).toHaveBeenCalled() + expect(pushMock).toHaveBeenCalledWith('/profile') + expect(registerMock).not.toHaveBeenCalled() + }) + + it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + apiClientPostMock.mockResolvedValue({ + data: { + auth_result: 'pending_session', + provider: 'oidc', + step: 'bind_login_required', + redirect: '/profile/security', + email: 'fresh@example.com', + }, + }) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('123456') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'fresh@example.com', + password: 'secret-123', + verify_code: '123456', + }) + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(setTokenMock).not.toHaveBeenCalled() + expect(persistOAuthTokenContextMock).not.toHaveBeenCalled() + expect(clearPendingAuthSessionMock).not.toHaveBeenCalled() + expect(showSuccessMock).not.toHaveBeenCalled() + }) + + it('keeps the normal email registration flow unchanged', async () => { + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'normal@example.com', + password: 'secret-456', + promo_code: 'PROMO', + invitation_code: 'INVITE', + }) + ) + registerMock.mockResolvedValue({}) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('654321') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(registerMock).toHaveBeenCalledWith({ + email: 'normal@example.com', + password: 'secret-456', + verify_code: '654321', + turnstile_token: undefined, + promo_code: 'PROMO', + invitation_code: 'INVITE', + }) + expect(apiClientPostMock).not.toHaveBeenCalled() + expect(pushMock).toHaveBeenCalledWith('/dashboard') + }) +}) diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..333f8dc54e04d22b46cd2cdd50b99d0abab16309 --- /dev/null +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -0,0 +1,739 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import LinuxDoCallbackView from '../LinuxDoCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const setPendingAuthSession = vi.fn() +const clearPendingAuthSession = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeLinuxDoOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() +const login2FA = vi.fn() +const apiClientPost = vi.fn() +const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? 
''}`.trim() + } + return key + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken, + setPendingAuthSession, + clearPendingAuthSession + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPost(...args) + } +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args), + login2FA: (...args: any[]) => login2FA(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) + } +}) + +describe('LinuxDoCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + setPendingAuthSession.mockReset() + clearPendingAuthSession.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeLinuxDoOAuthRegistration.mockReset() + getPublicSettings.mockReset() + login2FA.mockReset() + apiClientPost.mockReset() + sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() + getPublicSettings.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '' + }) + window.location.hash = '' + localStorage.clear() + }) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + window.location.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess') + expect(replace).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + apiClientPost.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + adopt_display_name: true, + adopt_avatar: true, + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code' + }) + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(replace).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('treats a completion without token as bind success and returns to profile', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile/security' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/profile/security', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('keeps rendering bind-login UI for legacy pending bind responses instead of treating them as success', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile/security', + email: 'existing@example.com' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setPendingAuthSession).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'linuxdo', + redirect: '/welcome' + }) + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[0].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + }) + + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('auth.oauthFlow.bindExistingAccount') + expect(wrapper.text()).toContain('auth.oauthFlow.createNewAccount') + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false + }) + expect(setToken).toHaveBeenCalledWith('new-access-token') + expect(replace).toHaveBeenCalledWith('/welcome') + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showError).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + sendPendingOAuthVerifyCode.mockResolvedValue({ + message: 'sent', + countdown: 60 + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ + email: 'new@example.com' + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="linuxdo-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="linuxdo-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="linuxdo-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true + }) + expect(setToken).toHaveBeenCalledWith('bind-access-token') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + apiClientPost.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com' + } + }) + login2FA.mockResolvedValue({ + access_token: '2fa-access-token' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="linuxdo-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="linuxdo-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FA).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="linuxdo-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="linuxdo-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456' + }) + expect(setToken).toHaveBeenCalledWith('2fa-access-token') + expect(replace).toHaveBeenCalledWith('/profile') + }) +}) diff --git a/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..1669f763f2cea57b5f6d8128974edca691161433 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OAuthCallbackView.spec.ts @@ -0,0 +1,68 @@ +import { mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import OAuthCallbackView from '@/views/auth/OAuthCallbackView.vue' + +const { routeState, showErrorMock, copyToClipboardMock } = vi.hoisted(() => ({ + routeState: { + query: {} as Record, + }, + showErrorMock: vi.fn(), + copyToClipboardMock: vi.fn(), +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +vi.mock('@/composables/useClipboard', () => ({ + useClipboard: () => ({ + copyToClipboard: (...args: any[]) => copyToClipboardMock(...args), + }), +})) + +describe('OAuthCallbackView', 
() => { + beforeEach(() => { + routeState.query = {} + showErrorMock.mockReset() + copyToClipboardMock.mockReset() + }) + + it('renders localized callback copy actions', () => { + routeState.query = { + code: 'oauth-code', + state: 'oauth-state', + } + + const wrapper = mount(OAuthCallbackView) + + expect(wrapper.text()).toContain('auth.oauth.callbackTitle') + expect(wrapper.text()).toContain('auth.oauth.callbackHint') + expect(wrapper.text()).toContain('common.copy') + expect(wrapper.find('input[value="oauth-code"]').exists()).toBe(true) + expect(wrapper.find('input[value="oauth-state"]').exists()).toBe(true) + }) + + it('sends callback errors to toast instead of rendering inline red text', () => { + routeState.query = { + error: 'oauth failed', + } + + const wrapper = mount(OAuthCallbackView) + + expect(showErrorMock).toHaveBeenCalledWith('oauth failed') + expect(wrapper.text()).not.toContain('oauth failed') + expect(wrapper.find('.bg-red-50').exists()).toBe(false) + }) +}) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec89512b3e0178b39f60e2be2670ac7daad134c2 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -0,0 +1,689 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import OidcCallbackView from '../OidcCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const setPendingAuthSession = vi.fn() +const clearPendingAuthSession = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeOIDCOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() +const login2FA = vi.fn() +const apiClientPost = vi.fn() +const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() + +vi.mock('vue-router', () => ({ + 
useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? ''}`.trim() + } + if (!params?.providerName) { + return key + } + return `${key}:${params.providerName}` + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken, + setPendingAuthSession, + clearPendingAuthSession + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPost(...args) + } +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args), + login2FA: (...args: any[]) => login2FA(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) + } +}) + +describe('OidcCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + setPendingAuthSession.mockReset() + clearPendingAuthSession.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeOIDCOAuthRegistration.mockReset() + getPublicSettings.mockReset() + login2FA.mockReset() + apiClientPost.mockReset() + sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID', + turnstile_enabled: false, + turnstile_site_key: '' + }) + window.location.hash = '' + localStorage.clear() + 
}) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + window.location.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccess).toHaveBeenCalledWith('auth.loginSuccess') + expect(replace).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + window.location.hash = '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + apiClientPost.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + adopt_display_name: true, + adopt_avatar: true, + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code' + }) + expect(setToken).toHaveBeenCalledWith('legacy-access-token') + expect(replace).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[0].setValue(false) + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: false, + adoptAvatar: true + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/profile', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccess).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setPendingAuthSession).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/welcome' + }) + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + }) + + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('auth.oauthFlow.bindExistingAccount') + expect(wrapper.text()).toContain('auth.oauthFlow.createNewAccount') + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID', + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false + }) + expect(setToken).toHaveBeenCalledWith('new-access-token') + expect(replace).toHaveBeenCalledWith('/welcome') + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showError).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + sendPendingOAuthVerifyCode.mockResolvedValue({ + message: 'sent', + countdown: 60 + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ + email: 'new@example.com' + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer' + } + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="oidc-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="oidc-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="oidc-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPost).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true + }) + expect(setToken).toHaveBeenCalledWith('bind-access-token') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + apiClientPost.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com' + } + }) + login2FA.mockResolvedValue({ + access_token: '2fa-access-token' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.get('[data-testid="oidc-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="oidc-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FA).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="oidc-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="oidc-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456' + }) + expect(setToken).toHaveBeenCalledWith('2fa-access-token') + expect(replace).toHaveBeenCalledWith('/profile') + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..7150dd7ec261823680161990752ef23c53a0ec26 --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -0,0 +1,1089 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' + +const { + exchangePendingOAuthCompletionMock, + completeWeChatOAuthRegistrationMock, + login2FAMock, + apiClientPostMock, + sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, + getPublicSettingsMock, + prepareOAuthBindAccessTokenCookieMock, + getAuthTokenMock, + replaceMock, + setTokenMock, + setPendingAuthSessionMock, + clearPendingAuthSessionMock, + showSuccessMock, + showErrorMock, + fetchPublicSettingsMock, + routeState, + locationState, + appStoreState, +} = vi.hoisted(() => ({ + exchangePendingOAuthCompletionMock: vi.fn(), + completeWeChatOAuthRegistrationMock: vi.fn(), + login2FAMock: vi.fn(), + apiClientPostMock: vi.fn(), + 
sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), + getPublicSettingsMock: vi.fn(), + prepareOAuthBindAccessTokenCookieMock: vi.fn(), + getAuthTokenMock: vi.fn(), + replaceMock: vi.fn(), + setTokenMock: vi.fn(), + setPendingAuthSessionMock: vi.fn(), + clearPendingAuthSessionMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + fetchPublicSettingsMock: vi.fn(), + routeState: { + query: {} as Record, + }, + locationState: { + current: { + href: 'http://localhost/auth/wechat/callback', + hash: '', + search: '', + pathname: '/auth/wechat/callback' + } as { href: string; hash: string; search: string; pathname: string }, + }, + appStoreState: { + cachedPublicSettings: null as null | Record, + publicSettingsLoaded: false, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oauthFlow.totpHint') { + return `verify ${params?.account ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackTitle') { + return `Signing you in with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackProcessing') { + return `Completing login with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.invitationRequired') { + return `${params?.providerName ?? 
''} invitation required`.trim() + } + if (key === 'auth.oidc.completeRegistration') { + return 'Complete registration' + } + if (key === 'auth.oidc.completing') { + return 'Completing' + } + if (key === 'auth.oidc.backToLogin') { + return 'Back to login' + } + if (key === 'auth.invitationCodePlaceholder') { + return 'Invitation code' + } + if (key === 'auth.loginSuccess') { + return 'Login success' + } + if (key === 'auth.loginFailed') { + return 'Login failed' + } + if (key === 'auth.oidc.callbackHint') { + return 'Callback hint' + } + if (key === 'auth.oidc.callbackMissingToken') { + return 'Missing login token' + } + if (key === 'auth.oidc.completeRegistrationFailed') { + return 'Complete registration failed' + } + return key + }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken: setTokenMock, + setPendingAuthSession: setPendingAuthSessionMock, + clearPendingAuthSession: clearPendingAuthSessionMock, + }), + useAppStore: () => ({ + ...appStoreState, + showSuccess: showSuccessMock, + showError: showErrorMock, + fetchPublicSettings: fetchPublicSettingsMock, + }), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: (...args: any[]) => apiClientPostMock(...args), + }, +})) + +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args), + completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), + login2FA: (...args: any[]) => login2FAMock(...args), + sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), + getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), + prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args), + getAuthToken: (...args: any[]) => 
getAuthTokenMock(...args), + } +}) + +describe('WechatCallbackView', () => { + beforeEach(() => { + exchangePendingOAuthCompletionMock.mockReset() + completeWeChatOAuthRegistrationMock.mockReset() + login2FAMock.mockReset() + apiClientPostMock.mockReset() + sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() + getPublicSettingsMock.mockReset() + replaceMock.mockReset() + setTokenMock.mockReset() + setPendingAuthSessionMock.mockReset() + clearPendingAuthSessionMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + prepareOAuthBindAccessTokenCookieMock.mockReset() + getAuthTokenMock.mockReset() + fetchPublicSettingsMock.mockReset() + routeState.query = {} + appStoreState.cachedPublicSettings = null + appStoreState.publicSettingsLoaded = false + localStorage.clear() + locationState.current = { + href: 'http://localhost/auth/wechat/callback', + hash: '', + search: '', + pathname: '/auth/wechat/callback' + } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + }) + }) + + it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'mp', + redirect: '/profile', + } + appStoreState.cachedPublicSettings = { + wechat_oauth_enabled: true, + wechat_oauth_open_enabled: true, + wechat_oauth_mp_enabled: false, + } + appStoreState.publicSettingsLoaded = true + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).not.toContain('mode=mp') + }) + + it('falls back to the query mode when capability settings cannot be confirmed', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'mp', + redirect: '/profile', + } + fetchPublicSettingsMock.mockResolvedValue(null) + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=mp') + }) + + it('ignores legacy aggregate wechat settings and reuses the query mode during bind recovery', async () => { + routeState.query = { + wechat_bind_existing: '1', + mode: 'open', + redirect: '/profile', + } + appStoreState.cachedPublicSettings = { + wechat_oauth_enabled: true, + } + appStoreState.publicSettingsLoaded = true + getAuthTokenMock.mockReturnValue('current-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('mode=open') + }) + + it('accepts the legacy fragment token success callback without pending-session exchange', async () => { + locationState.current.hash = + '#access_token=legacy-access-token&refresh_token=legacy-refresh-token&expires_in=3600&token_type=Bearer&redirect=%2Flegacy-dashboard' + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token') + expect(localStorage.getItem('refresh_token')).toBe('legacy-refresh-token') + expect(localStorage.getItem('token_expires_at')).not.toBeNull() + expect(showSuccessMock).toHaveBeenCalledWith('Login success') + expect(replaceMock).toHaveBeenCalledWith('/legacy-dashboard') + }) + + it('accepts the legacy pending oauth invitation fragment without pending-session exchange', async () => { + locationState.current.hash = + '#error=invitation_required&pending_oauth_token=legacy-pending-token&redirect=%2Flegacy-invite' + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'legacy-access-token', + refresh_token: 'legacy-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + pending_oauth_token: 'legacy-pending-token', + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('legacy-access-token') + expect(replaceMock).toHaveBeenCalledWith('/legacy-invite') + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith() + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1) + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') + expect(replaceMock).toHaveBeenCalledWith('/dashboard') + expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + redirect: '/profile/connections', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(clearPendingAuthSessionMock).toHaveBeenCalledTimes(1) + expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replaceMock).toHaveBeenCalledWith('/profile/connections') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ') + await wrapper.get('button').trigger('click') + await flushPromises() + + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', { + adoptDisplayName: false, + adoptAvatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') + expect(replaceMock).toHaveBeenCalledWith('/subscriptions') + }) + + it('keeps the oauth flow active when complete-registration returns another pending step', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + auth_result: 'pending_session', + step: 'choose_account_action_required', + redirect: '/dashboard', + email: 'fresh@example.com', + resolved_email: 'fresh@example.com', + force_email_on_signup: true, + adoption_required: true, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + await flushPromises() + + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + expect(wrapper.get('[data-testid="wechat-choice-bind-existing"]').exists()).toBe(true) + expect(wrapper.get('[data-testid="wechat-choice-create-account"]').exists()).toBe(true) + }) + + it('offers existing-account email collection during invitation flow', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue(null) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const emailInput = wrapper.get('[data-testid="existing-account-email"]') + await emailInput.setValue('user@example.com') + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click') + + expect(replaceMock).toHaveBeenCalledTimes(1) + expect(replaceMock.mock.calls[0]?.[0]).toContain('/login?') + expect(replaceMock.mock.calls[0]?.[0]).toContain('wechat_bind_existing%3D1') + expect(replaceMock.mock.calls[0]?.[0]).toContain('email=user%40example.com') + expect(replaceMock.mock.calls[0]?.[0]).toContain('mode%3Dopen') + }) + + it('binds directly to the current signed-in account during invitation flow', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue('current-auth-token') + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.find('[data-testid="existing-account-email"]').exists()).toBe(false) + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click') + + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fusage') + expect(locationState.current.href).toContain('mode=open') + }) + + it('shows an error and stays on the page when preparing bind-token for the current account fails', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/usage', + }) + getAuthTokenMock.mockReturnValue('current-auth-token') + prepareOAuthBindAccessTokenCookieMock.mockRejectedValue(new Error('bind token failed')) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="existing-account-submit"]').trigger('click').catch(() => undefined) + await flushPromises() + + expect(showErrorMock).toHaveBeenCalledWith('bind token failed') + expect(locationState.current.href).toBe('http://localhost/auth/wechat/callback') + }) + + it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '', + }) + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'new-access-token', + refresh_token: 'new-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'new@example.com', + password: 'secret-123', + verify_code: '246810', + invitation_code: 'INVITE123', + adopt_display_name: true, + adopt_avatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('new-access-token') + expect(replaceMock).toHaveBeenCalledWith('/welcome') + }) + + it('persists a pending auth session when the oauth flow still needs account creation', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'wechat', + redirect: '/welcome', + }) + }) + + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + apiClientPostMock.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists', + }, + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('shows create-account failures through toast without inline error text', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + apiClientPostMock.mockRejectedValue(new Error('create failed')) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('new@example.com') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect(showErrorMock).toHaveBeenCalledWith('create failed') + expect(wrapper.text()).not.toContain('create failed') + }) + + it('sends a verify code for pending oauth account creation', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ + message: 'sent', + countdown: 60, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') + await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click') + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'new@example.com', + }) + }) + + it('shows bind-login form for existing account binding and submits credentials with adoption decisions', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/profile/security', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + access_token: 'bind-access-token', + refresh_token: 'bind-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('[data-testid="wechat-bind-login-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="wechat-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="wechat-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/bind-login', { + email: 'existing@example.com', + password: 'secret-password', + adopt_display_name: false, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('bind-access-token') + expect(replaceMock).toHaveBeenCalledWith('/profile/security') + }) + + it('allows switching from server-driven bind-login to create-account mode', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/welcome', + email: 'existing@example.com', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('button.btn-secondary').trigger('click') + await flushPromises() + + const createAccountEmail = wrapper.get('[data-testid="wechat-create-account-email"]') + expect((createAccountEmail.element as HTMLInputElement).value).toBe('existing@example.com') + }) + + it('reuses query email for bind-login when backend does not echo it back', async () => { + routeState.query = { + email: 'resume@example.com', + } + exchangePendingOAuthCompletionMock.mockResolvedValue({ + step: 'bind_login_required', + redirect: '/profile', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + const bindEmail = wrapper.get('[data-testid="wechat-bind-login-email"]') + expect((bindEmail.element as HTMLInputElement).value).toBe('resume@example.com') + }) + + it('keeps rendering pending bind-login UI when adoption confirmation leads to another pending step', async () => { + exchangePendingOAuthCompletionMock + .mockResolvedValueOnce({ + redirect: '/profile', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + .mockResolvedValueOnce({ + step: 'bind_login_required', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(showSuccessMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + + it('handles bind-login 2FA challenge before redirecting', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'adopt_existing_user_by_email', + redirect: '/profile', + email: 'existing@example.com', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + apiClientPostMock.mockResolvedValue({ + data: { + requires_2fa: true, + temp_token: 'temp-123', + user_email_masked: 'o***g@example.com', + }, + }) + login2FAMock.mockResolvedValue({ + access_token: '2fa-access-token', + refresh_token: '2fa-refresh-token', + expires_in: 3600, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + await wrapper.get('[data-testid="wechat-bind-login-password"]').setValue('secret-password') + await wrapper.get('[data-testid="wechat-bind-login-submit"]').trigger('click') + await flushPromises() + + expect(wrapper.text()).toContain('o***g@example.com') + expect(login2FAMock).not.toHaveBeenCalled() + + await wrapper.get('[data-testid="wechat-bind-login-totp"]').setValue('123456') + await wrapper.get('[data-testid="wechat-bind-login-totp-submit"]').trigger('click') + await flushPromises() + + expect(login2FAMock).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '123456', + }) + expect(setTokenMock).toHaveBeenCalledWith('2fa-access-token') + expect(replaceMock).toHaveBeenCalledWith('/profile') + expect(localStorage.getItem('refresh_token')).toBe('2fa-refresh-token') + }) + + it('restarts the current-user bind flow after returning from login', async () => { + routeState.query = { + wechat_bind_existing: '1', + redirect: '/profile', + mode: 'mp', + } + getAuthTokenMock.mockReturnValue('existing-auth-token') + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(prepareOAuthBindAccessTokenCookieMock).toHaveBeenCalledTimes(1) + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=mp') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') + }) + + it('redirects back to login instead of falling through when bind-existing resume has no auth token', async () => { + routeState.query = { + wechat_bind_existing: '1', + redirect: '/profile', + mode: 'mp', + email: 'resume@example.com', + } + getAuthTokenMock.mockReturnValue(null) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).not.toHaveBeenCalled() + expect(replaceMock).toHaveBeenCalledTimes(1) + expect(replaceMock.mock.calls[0]?.[0]).toContain('/login?') + expect(replaceMock.mock.calls[0]?.[0]).toContain('wechat_bind_existing%3D1') + expect(replaceMock.mock.calls[0]?.[0]).toContain('mode%3Dmp') + expect(replaceMock.mock.calls[0]?.[0]).toContain('email=resume%40example.com') + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..822a083b79386619a0eddd5356e5c07c67cc5d5c --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatPaymentCallbackView.spec.ts @@ -0,0 +1,116 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatPaymentCallbackView from '@/views/auth/WechatPaymentCallbackView.vue' + +const { replaceMock, routeState, locationState, showErrorMock } = vi.hoisted(() => ({ + replaceMock: vi.fn(), + routeState: { + query: {} as Record, + }, + locationState: { + current: { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string }, + }, + showErrorMock: vi.fn(), +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => { + if (key === 'auth.wechatPayment.callbackTitle') return '正在恢复微信支付' + if (key === 'auth.wechatPayment.callbackProcessing') return '正在恢复微信支付...' 
+ if (key === 'auth.wechatPayment.backToPayment') return '返回支付页' + if (key === 'auth.wechatPayment.callbackMissingResumeToken') return '微信支付回调缺少恢复令牌。' + return key + }, + locale: { value: 'zh-CN' }, + }), +})) + +vi.mock('@/stores', () => ({ + useAppStore: () => ({ + showError: (...args: any[]) => showErrorMock(...args), + }), +})) + +describe('WechatPaymentCallbackView', () => { + beforeEach(() => { + replaceMock.mockReset() + showErrorMock.mockReset() + routeState.query = {} + locationState.current = { + href: 'http://localhost/auth/wechat/payment/callback', + hash: '', + search: '', + pathname: '/auth/wechat/payment/callback', + origin: 'http://localhost', + } as Location & { origin: string } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + }) + + it('redirects back to purchase with an opaque resume token from hash fragment', async () => { + locationState.current.hash = '#wechat_resume_token=resume-token-123&redirect=%2Fpurchase%3Ffrom%3Dwechat' + + mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).toHaveBeenCalledWith({ + path: '/purchase', + query: { + from: 'wechat', + wechat_resume: '1', + wechat_resume_token: 'resume-token-123', + }, + }) + }) + + it('redirects legacy openid callback payloads back to purchase while preserving resume context', async () => { + locationState.current.hash = + '#openid=openid-123&state=oauth-state&scope=snsapi_base&payment_type=wxpay_direct&amount=128&order_type=subscription&plan_id=7&redirect=%2Fpayment%3Ffrom%3Dwechat' + + mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).toHaveBeenCalledWith({ + path: '/purchase', + query: { + from: 'wechat', + wechat_resume: '1', + openid: 'openid-123', + state: 'oauth-state', + scope: 'snsapi_base', + payment_type: 'wxpay_direct', + amount: '128', + order_type: 'subscription', + plan_id: '7', + }, + }) + }) + + it('shows an error when the callback payload is missing the 
resume token', async () => { + locationState.current.hash = '#payment_type=wxpay' + + const wrapper = mount(WechatPaymentCallbackView) + await flushPromises() + + expect(replaceMock).not.toHaveBeenCalled() + expect(showErrorMock).toHaveBeenCalledWith('微信支付回调缺少恢复令牌。') + expect(wrapper.text()).toContain('微信支付回调缺少恢复令牌。') + expect(wrapper.find('.bg-red-50').exists()).toBe(false) + }) +}) diff --git a/frontend/src/views/user/AvailableChannelsView.vue b/frontend/src/views/user/AvailableChannelsView.vue new file mode 100644 index 0000000000000000000000000000000000000000..a6c9ebc89b9acab616c1a66efa51c5f07a2fd999 --- /dev/null +++ b/frontend/src/views/user/AvailableChannelsView.vue @@ -0,0 +1,127 @@ + + + diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue new file mode 100644 index 0000000000000000000000000000000000000000..61b4da97db9768df5f89c13bca48d7b90b4c13d5 --- /dev/null +++ b/frontend/src/views/user/ChannelStatusView.vue @@ -0,0 +1,172 @@ + + + diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 34cccf9c53ee469c7802a3089c315fb0a7ca6208..cf29e4bd54c2b437e1c10ee63cdee039c1a82a25 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -61,7 +61,7 @@ diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da52cb6a46ade857bdc662e04c76523ce2..7cb4343df4e814d89f84e164a7fa07fae784ed9e 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -23,20 +23,7 @@ :order-type="paymentState.orderType" @done="onPaymentDone" @success="onPaymentSuccess" - /> - - @@ -99,9 +86,6 @@ {{ t('payment.createOrder') }} ¥{{ totalAmount.toFixed(2) }} -
-

{{ errorMessage }}

-
@@ -185,9 +169,6 @@ {{ t('payment.createOrder') }} ¥{{ (feeRate > 0 ? subTotalAmount : selectedPlan.price).toFixed(2) }} -
-

{{ errorMessage }}

-