Commit b9b4db3d authored by song's avatar song
Browse files

Merge upstream/main

parents 5a6f60a9 dae0d532
...@@ -14,6 +14,9 @@ backend/server ...@@ -14,6 +14,9 @@ backend/server
backend/sub2api backend/sub2api
backend/main backend/main
# Go 测试二进制
*.test
# 测试覆盖率 # 测试覆盖率
*.out *.out
coverage.html coverage.html
...@@ -80,6 +83,8 @@ temp/ ...@@ -80,6 +83,8 @@ temp/
*.log *.log
*.bak *.bak
.cache/ .cache/
.dev/
.serena/
# =================== # ===================
# 构建产物 # 构建产物
...@@ -123,6 +128,5 @@ backend/cmd/server/server ...@@ -123,6 +128,5 @@ backend/cmd/server/server
deploy/docker-compose.override.yml deploy/docker-compose.override.yml
.gocache/ .gocache/
vite.config.js vite.config.js
!docs/
docs/* docs/*
!docs/dependency-security.md .serena/
\ No newline at end of file
## 概述
全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
## 主要改动
### 1. 错误日志查询优化
**功能特性:**
- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
- 改进查询参数处理,简化代码结构
- 增强错误分类和标准化处理
- 支持错误解决状态追踪(resolved 字段)
**技术实现:**
- `ops_handler.go` - 新增单条错误日志查询接口
- `ops_repo.go` - 优化数据查询和过滤条件构建
- `ops_models.go` - 扩展错误日志数据模型
- 前端 API 接口同步更新
### 2. 告警静默功能
**功能特性:**
- 支持按规则、平台、分组、区域等维度静默告警
- 可设置静默时长和原因说明
- 静默记录可追溯,记录创建人和创建时间
- 自动过期机制,避免永久静默
**技术实现:**
- `037_ops_alert_silences.sql` - 新增告警静默表
- `ops_alerts.go` - 告警静默逻辑实现
- `ops_alerts_handler.go` - 告警静默 API 接口
- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
**数据库结构:**
| 字段 | 类型 | 说明 |
|------|------|------|
| rule_id | BIGINT | 告警规则 ID |
| platform | VARCHAR(64) | 平台标识 |
| group_id | BIGINT | 分组 ID(可选) |
| region | VARCHAR(64) | 区域(可选) |
| until | TIMESTAMPTZ | 静默截止时间 |
| reason | TEXT | 静默原因 |
| created_by | BIGINT | 创建人 ID |
### 3. 错误分类标准化
**功能特性:**
- 统一错误阶段分类(request|auth|routing|upstream|network|internal)
- 规范错误归属分类(client|provider|platform)
- 标准化错误来源分类(client_request|upstream_http|gateway)
- 自动迁移历史数据到新分类体系
**技术实现:**
- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
- 自动映射历史遗留分类到新标准
- 自动解决已恢复的上游错误(客户端状态码 < 400)
### 4. Gateway 服务集成
**功能特性:**
- 完善各 Gateway 服务的 Ops 集成
- 统一错误日志记录接口
- 增强上游错误追踪能力
**涉及服务:**
- `antigravity_gateway_service.go` - Antigravity 网关集成
- `gateway_service.go` - 通用网关集成
- `gemini_messages_compat_service.go` - Gemini 兼容层集成
- `openai_gateway_service.go` - OpenAI 网关集成
### 5. 前端 UI 优化
**代码重构:**
- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
- 优化错误日志表格组件,提升可读性
- 清理未使用的 i18n 翻译,减少冗余
- 统一组件代码风格和格式
- 优化骨架屏组件,更好匹配实际看板布局
**布局改进:**
- 修复模态框内容溢出和滚动问题
- 优化表格布局,使用 flex 布局确保正确显示
- 改进看板头部布局和交互
- 提升响应式体验
- 骨架屏支持全屏模式适配
**交互优化:**
- 优化告警事件卡片功能和展示
- 改进错误详情展示逻辑
- 增强请求详情模态框
- 完善运行时设置卡片
- 改进加载动画效果
### 6. 国际化完善
**文案补充:**
- 补充错误日志相关的英文翻译
- 添加告警静默功能的中英文文案
- 完善提示文本和错误信息
- 统一术语翻译标准
## 文件变更
**后端(26 个文件):**
- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
- `backend/internal/repository/ops_repo.go` - 数据访问层重构
- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件)
- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件)
- `backend/internal/server/routes/admin.go` - 路由配置更新
- `backend/migrations/*.sql` - 数据库迁移(2 个文件)
- 测试文件更新(5 个文件)
**前端(13 个文件):**
- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件)
- `frontend/src/api/admin/ops.ts` - API 接口扩展
- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件)
## 代码统计
- 44 个文件修改
- 3733 行新增
- 995 行删除
- 净增加 2738 行
## 核心改进
**可维护性提升:**
- 重构核心服务层,职责更清晰
- 简化前端组件代码,降低复杂度
- 统一代码风格和命名规范
- 清理冗余代码和未使用的翻译
- 标准化错误分类体系
**功能完善:**
- 告警静默功能,减少告警噪音
- 错误日志查询优化,提升运维效率
- Gateway 服务集成完善,统一监控能力
- 错误解决状态追踪,便于问题管理
**用户体验优化:**
- 修复多个 UI 布局问题
- 优化交互流程
- 完善国际化支持
- 提升响应式体验
- 改进加载状态展示
## 测试验证
- ✅ 错误日志查询和过滤功能
- ✅ 告警静默创建和自动过期
- ✅ 错误分类标准化迁移
- ✅ Gateway 服务错误日志记录
- ✅ 前端组件布局和交互
- ✅ 骨架屏全屏模式适配
- ✅ 国际化文本完整性
- ✅ API 接口功能正确性
- ✅ 数据库迁移执行成功
...@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( ...@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
--- ---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
---
## 部署方式 ## 部署方式
### 方式一:脚本安装(推荐) ### 方式一:脚本安装(推荐)
......
.cache/
.DS_Store
...@@ -18,6 +18,12 @@ linters: ...@@ -18,6 +18,12 @@ linters:
list-mode: original list-mode: original
files: files:
- "**/internal/service/**" - "**/internal/service/**"
- "!**/internal/service/ops_aggregation_service.go"
- "!**/internal/service/ops_alert_evaluator_service.go"
- "!**/internal/service/ops_cleanup_service.go"
- "!**/internal/service/ops_metrics_collector.go"
- "!**/internal/service/ops_scheduled_report_service.go"
- "!**/internal/service/wire.go"
deny: deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository" desc: "service must not import repository"
......
...@@ -33,7 +33,7 @@ func main() { ...@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil) authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
......
...@@ -62,6 +62,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { ...@@ -62,6 +62,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup( func provideCleanup(
entClient *ent.Client, entClient *ent.Client,
rdb *redis.Client, rdb *redis.Client,
opsMetricsCollector *service.OpsMetricsCollector,
opsAggregation *service.OpsAggregationService,
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
pricing *service.PricingService, pricing *service.PricingService,
...@@ -81,6 +87,42 @@ func provideCleanup( ...@@ -81,6 +87,42 @@ func provideCleanup(
name string name string
fn func() error fn func() error
}{ }{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
}
return nil
}},
{"OpsCleanupService", func() error {
if opsCleanup != nil {
opsCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
}
return nil
}},
{"OpsAggregationService", func() error {
if opsAggregation != nil {
opsAggregation.Stop()
}
return nil
}},
{"OpsMetricsCollector", func() error {
if opsMetricsCollector != nil {
opsMetricsCollector.Stop()
}
return nil
}},
{"SchedulerSnapshotService", func() error {
if schedulerSnapshot != nil {
schedulerSnapshot.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil
......
...@@ -51,33 +51,44 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -51,33 +51,44 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) promoCodeRepository := repository.NewPromoCodeRepository(client)
userService := service.NewUserService(userRepository) billingCache := repository.NewBillingCache(redisClient)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
userHandler := handler.NewUserHandler(userService) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
apiKeyRepository := repository.NewAPIKeyRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client) redeemCodeRepository := repository.NewRedeemCodeRepository(client)
billingCache := repository.NewBillingCache(redisClient)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository) dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardHandler := admin.NewDashboardHandler(dashboardService) dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
timingWheelService, err := service.ProvideTimingWheelService()
if err != nil {
return nil, err
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db) accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
...@@ -90,12 +101,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -90,12 +101,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher() claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache() usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
...@@ -105,14 +118,36 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -105,14 +118,36 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService) promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
schedulerCache := repository.NewSchedulerCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient) updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo) serviceBuildInfo := provideServiceBuildInfo(buildInfo)
...@@ -124,32 +159,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -124,32 +159,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
...@@ -174,6 +201,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { ...@@ -174,6 +201,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
func provideCleanup( func provideCleanup(
entClient *ent.Client, entClient *ent.Client,
rdb *redis.Client, rdb *redis.Client,
opsMetricsCollector *service.OpsMetricsCollector,
opsAggregation *service.OpsAggregationService,
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
pricing *service.PricingService, pricing *service.PricingService,
...@@ -192,6 +225,42 @@ func provideCleanup( ...@@ -192,6 +225,42 @@ func provideCleanup(
name string name string
fn func() error fn func() error
}{ }{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
}
return nil
}},
{"OpsCleanupService", func() error {
if opsCleanup != nil {
opsCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
}
return nil
}},
{"OpsAggregationService", func() error {
if opsAggregation != nil {
opsAggregation.Stop()
}
return nil
}},
{"OpsMetricsCollector", func() error {
if opsMetricsCollector != nil {
opsMetricsCollector.Stop()
}
return nil
}},
{"SchedulerSnapshotService", func() error {
if schedulerSnapshot != nil {
schedulerSnapshot.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil
......
...@@ -43,6 +43,8 @@ type Account struct { ...@@ -43,6 +43,8 @@ type Account struct {
Concurrency int `json:"concurrency,omitempty"` Concurrency int `json:"concurrency,omitempty"`
// Priority holds the value of the "priority" field. // Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"` Priority int `json:"priority,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// Status holds the value of the "status" field. // Status holds the value of the "status" field.
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// ErrorMessage holds the value of the "error_message" field. // ErrorMessage holds the value of the "error_message" field.
...@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) { ...@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte) values[i] = new([]byte)
case account.FieldAutoPauseOnExpired, account.FieldSchedulable: case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case account.FieldRateMultiplier:
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
...@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error { ...@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Priority = int(value.Int64) _m.Priority = int(value.Int64)
} }
case account.FieldRateMultiplier:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
case account.FieldStatus: case account.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i]) return fmt.Errorf("unexpected type %T for field status", values[i])
...@@ -420,6 +430,9 @@ func (_m *Account) String() string { ...@@ -420,6 +430,9 @@ func (_m *Account) String() string {
builder.WriteString("priority=") builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority)) builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
builder.WriteString("status=") builder.WriteString("status=")
builder.WriteString(_m.Status) builder.WriteString(_m.Status)
builder.WriteString(", ") builder.WriteString(", ")
......
...@@ -39,6 +39,8 @@ const ( ...@@ -39,6 +39,8 @@ const (
FieldConcurrency = "concurrency" FieldConcurrency = "concurrency"
// FieldPriority holds the string denoting the priority field in the database. // FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority" FieldPriority = "priority"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
// FieldStatus holds the string denoting the status field in the database. // FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status" FieldStatus = "status"
// FieldErrorMessage holds the string denoting the error_message field in the database. // FieldErrorMessage holds the string denoting the error_message field in the database.
...@@ -116,6 +118,7 @@ var Columns = []string{ ...@@ -116,6 +118,7 @@ var Columns = []string{
FieldProxyID, FieldProxyID,
FieldConcurrency, FieldConcurrency,
FieldPriority, FieldPriority,
FieldRateMultiplier,
FieldStatus, FieldStatus,
FieldErrorMessage, FieldErrorMessage,
FieldLastUsedAt, FieldLastUsedAt,
...@@ -174,6 +177,8 @@ var ( ...@@ -174,6 +177,8 @@ var (
DefaultConcurrency int DefaultConcurrency int
// DefaultPriority holds the default value on creation for the "priority" field. // DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int DefaultPriority int
// DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
DefaultRateMultiplier float64
// DefaultStatus holds the default value on creation for the "status" field. // DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save. // StatusValidator is a validator for the "status" field. It is called by the builders before save.
...@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption { ...@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc() return sql.OrderByField(FieldPriority, opts...).ToFunc()
} }
// ByRateMultiplier orders the results by the rate_multiplier field.
func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
// ByStatus orders the results by the status field. // ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption { func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc() return sql.OrderByField(FieldStatus, opts...).ToFunc()
......
...@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account { ...@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v)) return predicate.Account(sql.FieldEQ(FieldPriority, v))
} }
// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
func RateMultiplier(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. // Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Account { func Status(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v)) return predicate.Account(sql.FieldEQ(FieldStatus, v))
...@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account { ...@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldPriority, v)) return predicate.Account(sql.FieldLTE(FieldPriority, v))
} }
// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
func RateMultiplierEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
}
// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
func RateMultiplierNEQ(v float64) predicate.Account {
return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
}
// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
func RateMultiplierIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
}
// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
func RateMultiplierNotIn(vs ...float64) predicate.Account {
return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
}
// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
func RateMultiplierGT(v float64) predicate.Account {
return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
}
// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
func RateMultiplierGTE(v float64) predicate.Account {
return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
}
// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
func RateMultiplierLT(v float64) predicate.Account {
return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
}
// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
func RateMultiplierLTE(v float64) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
}
// StatusEQ applies the EQ predicate on the "status" field. // StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Account { func StatusEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v)) return predicate.Account(sql.FieldEQ(FieldStatus, v))
......
...@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate { ...@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
return _c return _c
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
_c.mutation.SetRateMultiplier(v)
return _c
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
if v != nil {
_c.SetRateMultiplier(*v)
}
return _c
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (_c *AccountCreate) SetStatus(v string) *AccountCreate { func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
_c.mutation.SetStatus(v) _c.mutation.SetStatus(v)
...@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error { ...@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
v := account.DefaultPriority v := account.DefaultPriority
_c.mutation.SetPriority(v) _c.mutation.SetPriority(v)
} }
if _, ok := _c.mutation.RateMultiplier(); !ok {
v := account.DefaultRateMultiplier
_c.mutation.SetRateMultiplier(v)
}
if _, ok := _c.mutation.Status(); !ok { if _, ok := _c.mutation.Status(); !ok {
v := account.DefaultStatus v := account.DefaultStatus
_c.mutation.SetStatus(v) _c.mutation.SetStatus(v)
...@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error { ...@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
if _, ok := _c.mutation.Priority(); !ok { if _, ok := _c.mutation.Priority(); !ok {
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)} return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
} }
if _, ok := _c.mutation.RateMultiplier(); !ok {
return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
}
if _, ok := _c.mutation.Status(); !ok { if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)} return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
} }
...@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { ...@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldPriority, field.TypeInt, value) _spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value _node.Priority = value
} }
if value, ok := _c.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
if value, ok := _c.mutation.Status(); ok { if value, ok := _c.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value) _spec.SetField(account.FieldStatus, field.TypeString, value)
_node.Status = value _node.Status = value
...@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert { ...@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
return u return u
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
u.Set(account.FieldRateMultiplier, v)
return u
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
u.SetExcluded(account.FieldRateMultiplier)
return u
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
u.Add(account.FieldRateMultiplier, v)
return u
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert { func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
u.Set(account.FieldStatus, v) u.Set(account.FieldStatus, v)
...@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne { ...@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
}) })
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne { func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) { return u.Update(func(s *AccountUpsert) {
...@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk { ...@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
}) })
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.SetRateMultiplier(v)
})
}
// AddRateMultiplier adds v to the "rate_multiplier" field.
func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.AddRateMultiplier(v)
})
}
// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
s.UpdateRateMultiplier()
})
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk { func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) { return u.Update(func(s *AccountUpsert) {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
...@@ -31,6 +32,7 @@ type AccountQuery struct { ...@@ -31,6 +32,7 @@ type AccountQuery struct {
withProxy *ProxyQuery withProxy *ProxyQuery
withUsageLogs *UsageLogQuery withUsageLogs *UsageLogQuery
withAccountGroups *AccountGroupQuery withAccountGroups *AccountGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -495,6 +497,9 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco ...@@ -495,6 +497,9 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -690,6 +695,9 @@ func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGro ...@@ -690,6 +695,9 @@ func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGro
func (_q *AccountQuery) sqlCount(ctx context.Context) (int, error) { func (_q *AccountQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -755,6 +763,9 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -755,6 +763,9 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -772,6 +783,32 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -772,6 +783,32 @@ func (_q *AccountQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AccountQuery) ForUpdate(opts ...sql.LockOption) *AccountQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AccountQuery) ForShare(opts ...sql.LockOption) *AccountQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AccountGroupBy is the group-by builder for Account entities. // AccountGroupBy is the group-by builder for Account entities.
type AccountGroupBy struct { type AccountGroupBy struct {
selector selector
......
...@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate { ...@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
return _u return _u
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate { func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
_u.mutation.SetStatus(v) _u.mutation.SetStatus(v)
...@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { ...@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedPriority(); ok { if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value) _spec.AddField(account.FieldPriority, field.TypeInt, value)
} }
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value) _spec.SetField(account.FieldStatus, field.TypeString, value)
} }
...@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne { ...@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
return _u return _u
} }
// SetRateMultiplier sets the "rate_multiplier" field.
func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.ResetRateMultiplier()
_u.mutation.SetRateMultiplier(v)
return _u
}
// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
if v != nil {
_u.SetRateMultiplier(*v)
}
return _u
}
// AddRateMultiplier adds value to the "rate_multiplier" field.
func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
_u.mutation.AddRateMultiplier(v)
return _u
}
// SetStatus sets the "status" field. // SetStatus sets the "status" field.
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne { func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
_u.mutation.SetStatus(v) _u.mutation.SetStatus(v)
...@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er ...@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedPriority(); ok { if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value) _spec.AddField(account.FieldPriority, field.TypeInt, value)
} }
if value, ok := _u.mutation.RateMultiplier(); ok {
_spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
}
if value, ok := _u.mutation.Status(); ok { if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value) _spec.SetField(account.FieldStatus, field.TypeString, value)
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/account"
...@@ -25,6 +26,7 @@ type AccountGroupQuery struct { ...@@ -25,6 +26,7 @@ type AccountGroupQuery struct {
predicates []predicate.AccountGroup predicates []predicate.AccountGroup
withAccount *AccountQuery withAccount *AccountQuery
withGroup *GroupQuery withGroup *GroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -347,6 +349,9 @@ func (_q *AccountGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] ...@@ -347,6 +349,9 @@ func (_q *AccountGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -432,6 +437,9 @@ func (_q *AccountGroupQuery) loadGroup(ctx context.Context, query *GroupQuery, n ...@@ -432,6 +437,9 @@ func (_q *AccountGroupQuery) loadGroup(ctx context.Context, query *GroupQuery, n
func (_q *AccountGroupQuery) sqlCount(ctx context.Context) (int, error) { func (_q *AccountGroupQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Unique = false _spec.Unique = false
_spec.Node.Columns = nil _spec.Node.Columns = nil
return sqlgraph.CountNodes(ctx, _q.driver, _spec) return sqlgraph.CountNodes(ctx, _q.driver, _spec)
...@@ -495,6 +503,9 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -495,6 +503,9 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -512,6 +523,32 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -512,6 +523,32 @@ func (_q *AccountGroupQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *AccountGroupQuery) ForUpdate(opts ...sql.LockOption) *AccountGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *AccountGroupQuery) ForShare(opts ...sql.LockOption) *AccountGroupQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// AccountGroupGroupBy is the group-by builder for AccountGroup entities. // AccountGroupGroupBy is the group-by builder for AccountGroup entities.
type AccountGroupGroupBy struct { type AccountGroupGroupBy struct {
selector selector
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package ent package ent
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -35,6 +36,10 @@ type APIKey struct { ...@@ -35,6 +36,10 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field. // Status holds the value of the "status" field.
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
IPBlacklist []string `json:"ip_blacklist,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set. // The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"` Edges APIKeyEdges `json:"edges"`
...@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { ...@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
...@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { ...@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Status = value.String _m.Status = value.String
} }
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
}
}
case apikey.FieldIPBlacklist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
...@@ -245,6 +268,12 @@ func (_m *APIKey) String() string { ...@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("status=") builder.WriteString("status=")
builder.WriteString(_m.Status) builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")
builder.WriteString("ip_blacklist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }
......
...@@ -31,6 +31,10 @@ const ( ...@@ -31,6 +31,10 @@ const (
FieldGroupID = "group_id" FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database. // FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status" FieldStatus = "status"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
FieldIPBlacklist = "ip_blacklist"
// EdgeUser holds the string denoting the user edge name in mutations. // EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user" EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations. // EdgeGroup holds the string denoting the group edge name in mutations.
...@@ -73,6 +77,8 @@ var Columns = []string{ ...@@ -73,6 +77,8 @@ var Columns = []string{
FieldName, FieldName,
FieldGroupID, FieldGroupID,
FieldStatus, FieldStatus,
FieldIPWhitelist,
FieldIPBlacklist,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
......
...@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey { ...@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
} }
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
}
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
func IPWhitelistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
}
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
func IPBlacklistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
}
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
func IPBlacklistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
}
// HasUser applies the HasEdge predicate on the "user" edge. // HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey { func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) { return predicate.APIKey(func(s *sql.Selector) {
......
...@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { ...@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c return _c
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
return _c
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
_c.mutation.SetIPBlacklist(v)
return _c
}
// SetUser sets the "user" edge to the User entity. // SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID) return _c.SetUserID(v.ID)
...@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { ...@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value) _spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value _node.Status = value
} }
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
}
if value, ok := _c.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
_node.IPBlacklist = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.M2O,
...@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { ...@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u return u
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
return u
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPWhitelist)
return u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPWhitelist)
return u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPBlacklist, v)
return u
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPBlacklist)
return u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPBlacklist)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
...@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { ...@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
...@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { ...@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
}) })
} }
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query. // Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"math" "math"
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field" "entgo.io/ent/schema/field"
...@@ -29,6 +30,7 @@ type APIKeyQuery struct { ...@@ -29,6 +30,7 @@ type APIKeyQuery struct {
withUser *UserQuery withUser *UserQuery
withGroup *GroupQuery withGroup *GroupQuery
withUsageLogs *UsageLogQuery withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
...@@ -458,6 +460,9 @@ func (_q *APIKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*APIKe ...@@ -458,6 +460,9 @@ func (_q *APIKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*APIKe
node.Edges.loadedTypes = loadedTypes node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values) return node.assignValues(columns, values)
} }
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks { for i := range hooks {
hooks[i](ctx, _spec) hooks[i](ctx, _spec)
} }
...@@ -583,6 +588,9 @@ func (_q *APIKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, ...@@ -583,6 +588,9 @@ func (_q *APIKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery,
func (_q *APIKeyQuery) sqlCount(ctx context.Context) (int, error) { func (_q *APIKeyQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec() _spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields _spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 { if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
...@@ -651,6 +659,9 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -651,6 +659,9 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique { if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct() selector.Distinct()
} }
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates { for _, p := range _q.predicates {
p(selector) p(selector)
} }
...@@ -668,6 +679,32 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector { ...@@ -668,6 +679,32 @@ func (_q *APIKeyQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector return selector
} }
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *APIKeyQuery) ForUpdate(opts ...sql.LockOption) *APIKeyQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *APIKeyQuery) ForShare(opts ...sql.LockOption) *APIKeyQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// APIKeyGroupBy is the group-by builder for APIKey entities. // APIKeyGroupBy is the group-by builder for APIKey entities.
type APIKeyGroupBy struct { type APIKeyGroupBy struct {
selector selector
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment