Unverified commit 07bb2a5f authored by Wesley Liddick, committed by GitHub

Merge pull request #952 from xvhuan/feat/billing-ledger-decouple-usage-log-20260312

feat: decouple billing correctness from usage_logs bulk-write pressure
parents 417861a4 64b3f3ce
# PR Report: Investigating DB Write Hotspots and Admin Query Congestion
## Background
Several distinct symptoms showed up in production during peak hours:
- Admin dashboard endpoints timed out frequently; `/api/v1/admin/dashboard/snapshot-v2` latency exceeded 50s at one point
- The admin balance endpoint `/api/v1/admin/users/:id/balance` saw timeouts above 15s
- Session refresh, billing deduction, and error logging produced large numbers of `context deadline exceeded` errors at peak
- PostgreSQL once had its connections exhausted; after the connection-pool change was rolled back, the dominant problem shifted to WAL/flush congestion
This report is based on the current source at `/home/ius/sub2api`; its goal is a fix plan that can be split directly into PRs.
## Conclusion
The root cause of this incident was not a single slow SQL statement. The request success path performs too many synchronous writes, and on top of that some admin queries still scan `usage_logs` directly; together these amplify PostgreSQL WAL flushing, hot-row updates, and the outbox rebuild chain all at once.
There are six core problems at the code level.
### 1. Too many synchronous writes on the request success path
`postUsageBilling` (`backend/internal/service/gateway_service.go:6594`) may trigger all of the following synchronous writes after a single successful request:
- `userRepo.DeductBalance`
- `APIKeyService.UpdateQuotaUsed`
- `APIKeyService.UpdateRateLimitUsage`
- `accountRepo.IncrementQuotaUsed`
- `deferredService.ScheduleLastUsedUpdate` (already deferred and batched; the right direction)
In other words, one successful request means not 1 write but 3 to 5.
This matches what was observed in production:
- `UPDATE accounts SET extra = ...`
- `INSERT INTO usage_logs ...`
- `INSERT INTO ops_error_logs ...`
- `scheduler_outbox` backlog
### 2. API key quota updates add extra read/write amplification
`UpdateQuotaUsed` (`backend/internal/service/api_key_service.go:815`) currently does:
1. `IncrementQuotaUsed`
2. `GetByID`
3. `Update` again if the quota is exceeded
In the repository layer:
- `backend/internal/repository/api_key_repo.go:441` only performs the increment
- the service then reads the full API key back from the table
- and may then issue a full-row update to change the status
This turns "update the API key quota after each deduction" from 1 SQL statement into up to 3 database round trips.
### 3. `accounts.extra` is used as a high-frequency hot-write field
The two heaviest hotspots both land on `accounts.extra`:
- `backend/internal/repository/account_repo.go:1159` `UpdateExtra`
- `backend/internal/repository/account_repo.go:1683` `IncrementQuotaUsed`
There are two problems:
1. Both rewrite the entire JSONB blob and touch `updated_at`
2. `UpdateExtra` additionally inserts a `scheduler_outbox` row on every write
`UpdateExtra` in particular is now called at high frequency from several places:
- `backend/internal/service/openai_gateway_service.go:4039` persisting the Codex rate-limit snapshot
- `backend/internal/service/ratelimit_service.go:903` persisting the OpenAI Codex snapshot
- `backend/internal/service/ratelimit_service.go:1013` / `1025` updating session window utilization
These "monitoring/quota snapshot" writes never change whether an account is schedulable, yet they still pay for:
- a JSONB update
- an `updated_at` touch
- a `scheduler_outbox` insert
This is clear write amplification.
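The direction proposed later in PR 1 can be sketched in SQL. Note that `account_quota_counters` and all column names here are hypothetical illustrations, not the repo's actual schema:

```sql
-- Today (write amplification): every counter bump rewrites the whole JSONB blob.
UPDATE accounts
SET extra = jsonb_set(
        COALESCE(extra, '{}'::jsonb),
        '{quota_used}',
        to_jsonb(COALESCE((extra->>'quota_used')::numeric, 0) + 0.01)),
    updated_at = NOW()
WHERE id = 42;

-- Proposed: a narrow counters table touches one small row and far less WAL.
UPDATE account_quota_counters
SET quota_used = quota_used + 0.01
WHERE account_id = 42;
```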
### 4. `scheduler_outbox` writes one row per state change, which back-pressures the scheduler at peak
`enqueueSchedulerOutbox` (`backend/internal/repository/scheduler_outbox_repo.go:79`) is very cheap on its own, but it is called constantly.
For example:
- `UpdateExtra` enqueues an `AccountChanged` event every time
- `BatchUpdateLastUsed` also enqueues an `AccountLastUsed` event
- account rate-limit, overload, and error state transitions all enqueue as well
The corresponding outbox workers live at:
- `backend/internal/service/scheduler_snapshot_service.go:199`
- `backend/internal/service/scheduler_snapshot_service.go:219`
They continuously drain the outbox and then trigger `GetByID`, `rebuildBucket`, and `loadAccountsFromDB`.
So when high-frequency writes grow the outbox, the system not only gains extra writes; it also drags in extra reads and cache rebuilds in return.
### 5. Only part of the dashboard uses pre-aggregation; `models/groups/users-trend` still scan `usage_logs` directly
The good news: `dashboard stats` itself is already backed by pre-aggregation tables:
- `backend/internal/repository/usage_log_repo.go:306`
- `backend/internal/repository/usage_log_repo.go:420`
- table definitions in `backend/migrations/034_usage_dashboard_aggregation_tables.sql:1`
But stats is not the only slow part of the admin UI.
By default, `snapshot-v2` fetches all of:
- stats
- trend
- model stats
See:
- `backend/internal/handler/admin/dashboard_snapshot_v2_handler.go:68`
Of these:
- `GetUsageTrendWithFilters` only uses pre-aggregation in the unfiltered day/hour case, see `usage_log_repo.go:1657`
- `GetModelStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1805`
- `GetGroupStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1872`
- `GetUserUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1101`
- `GetAPIKeyUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1046`
Which is exactly why production shows:
- stats is fast
- but `snapshot-v2` as a whole is still slow
- and `/admin/dashboard/users-trend` is slow on its own as well
This matches the production logs exactly.
### 6. Admin balance top-up is "read user -> update whole user -> insert audit record"
`UpdateUserBalance` (`backend/internal/service/admin_service.go:694`) currently does:
1. `GetByID`
2. modify the balance in memory
3. `userRepo.Update`
4. `redeemCodeRepo.Create` to record the admin adjustment history
`userRepo.Update` updates the entire user object and synchronously handles the allowed-groups transaction:
- `backend/internal/repository/user_repo.go:118`
This endpoint is not necessarily heavy in normal operation, but when the database is already struggling it is far more fragile than a single atomic `UPDATE users SET balance = balance + $1`.
## Additional observations
### `ops_error_logs` is already asynchronous, but each row is still heavy
The error-log middleware already queues writes to smooth out spikes:
- `backend/internal/handler/ops_error_logger.go:69`
- `backend/internal/handler/ops_error_logger.go:106`
That direction is correct.
But the target table itself is heavy:
- `backend/internal/repository/ops_repo.go:23`
- `backend/migrations/033_ops_monitoring_vnext.sql:69`
- `backend/migrations/033_ops_monitoring_vnext.sql:470`
`ops_error_logs` has many columns plus several B-Tree and trigram indexes. Under a high error rate, even asynchronous writes still push WAL and I/O load up.
## Suggested PR breakdown
Split the work into 4 PRs; do not change the database model, admin queries, and admin endpoints in a single PR.
### PR 1: Shrink the number of synchronous writes on the success path
Goal: cut the synchronous writes per successful request from 3 to 5 down to 1 to 2.
Suggested changes:
1. Convert `APIKeyService.UpdateQuotaUsed` into a single SQL statement
- add a repo method, e.g. `IncrementQuotaUsedAndMaybeExhaust`
- perform both `quota_used += ?` and the `status = quota_exhausted` flip inside one statement
- return only the minimal `key/status/quota/quota_used` fields and invalidate the cache directly
- delete the current `Increment -> GetByID -> Update` sequence
2. Move account quota counters out of `accounts.extra`
- ideal: add structured columns or a dedicated `account_quota_counters` table
- fallback: at minimum pull `quota_used/quota_daily_used/quota_weekly_used` out of the JSONB
3. Stop enqueueing outbox events for monitoring-only `extra` fields
- e.g. the Codex snapshot and `session_window_utilization`
- these fields do not affect scheduling and should not trigger `SchedulerOutboxEventAccountChanged`
4. Reuse the existing `DeferredService` approach
- `last_used` is already flushed in batches, see `deferred_service.go:41`
- this can be extended into a deferred quota snapshot flush
Expected gains:
- directly reduces WAL volume
- lowers hot-row lock contention on `accounts`
- slows outbox growth
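Item 1 can be sketched as one statement. Column and status names here follow the report's wording and should be verified against the real `api_keys` schema before use:

```sql
-- Single round trip: increment quota_used, flip status when the quota is
-- crossed, and return the minimal fields needed to invalidate the cache.
UPDATE api_keys
SET quota_used = quota_used + $1,
    status = CASE
        WHEN quota > 0 AND quota_used + $1 >= quota THEN 'quota_exhausted'
        ELSE status
    END,
    updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING status, quota, quota_used;
```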
### PR 2: Complete the dashboard pre-aggregation/caching so nothing keeps scanning `usage_logs`
Goal: admin dashboard endpoints stop scanning the large table directly.
Suggested changes:
1. Add hour/day-level pre-aggregation tables for `users-trend` / `api-keys-trend`
2. Add day-level aggregation tables for `model stats` / `group stats`
3. Add per-section caching to `snapshot-v2`
- `stats`
- `trend`
- `models`
- `groups`
- `users_trend`
so that one section miss does not force the whole snapshot to rescan the database
4. Optional: flip the `include_model_stats` default from `true` to `false`
- at least make the default dashboard usable again first, then load the heavy modules on demand
Expected gains:
- `snapshot-v2`
- `/admin/dashboard/users-trend`
- `/admin/dashboard/api-keys-trend`
These endpoints go from "degrading linearly with data volume" to "roughly fixed cost".
### PR 3: Simplify the admin balance update path
Goal: admin top-up/deduction no longer depends on full-object user updates.
Suggested changes:
1. Add atomic repo methods
- `SetBalance(userID, amount)`
- `AddBalance(userID, delta)`
- `SubtractBalance(userID, delta)`
2. Change `UpdateUserBalance` to:
- run the atomic SQL first
- then read back only the minimal fields needed for the response
- write the audit record asynchronously or with a degradation path
3. Rename the audit record or move it to its own table
- stuffing admin adjustments into `redeem_codes` is semantically unclean
Expected gains:
- `/api/v1/admin/users/:id/balance` stays stable while the database is under pressure
- smaller failure surface; no longer dragged down by the allowed-groups sync transaction
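The atomic method in item 1 reduces the balance change to one statement. Column names are assumed from the report; a guard against negative balances may also be wanted:

```sql
-- Sketch of AddBalance/SubtractBalance: one atomic statement instead of
-- read-modify-write on the whole user object.
UPDATE users
SET balance = balance + $1,
    updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING balance;
```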
### PR 4: Add drop policies and circuit-breaker metrics for write-heavy paths
Goal: protect the main path at peak; do not let non-critical writes drag the database down.
Suggested changes:
1. `ops_error_logs`
- add sampling or severity-based switches
- aggregate-count repeated 429/5xx errors instead of inserting row by row
- add stricter switches around storing request bodies / headers
2. `scheduler_outbox`
- add a coalesce/merge mechanism
- merge multiple `AccountChanged` events for the same account within a short window into one
3. Fill in the metrics
- outbox backlog
- ops error queue dropped
- deferred flush lag
- account extra write QPS
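The coalesce/merge idea in item 2 can be sketched as a pure function over a batch of outbox events. The event shape here is hypothetical, chosen only to illustrate the merge rule:

```go
package main

import "fmt"

// outboxEvent is a hypothetical shape for scheduler_outbox rows; field names
// are assumptions, not the repo's actual columns.
type outboxEvent struct {
	AccountID int64
	Kind      string // e.g. "AccountChanged", "AccountLastUsed"
}

// coalesce keeps only the latest event per (AccountID, Kind) within a batch,
// so N AccountChanged events for one account collapse into a single event
// while preserving the first-seen ordering of distinct keys.
func coalesce(events []outboxEvent) []outboxEvent {
	type key struct {
		accountID int64
		kind      string
	}
	seen := make(map[key]int) // key -> index into out
	out := make([]outboxEvent, 0, len(events))
	for _, ev := range events {
		k := key{ev.AccountID, ev.Kind}
		if idx, ok := seen[k]; ok {
			out[idx] = ev // keep the latest occurrence
			continue
		}
		seen[k] = len(out)
		out = append(out, ev)
	}
	return out
}

func main() {
	batch := []outboxEvent{
		{1, "AccountChanged"},
		{1, "AccountChanged"},
		{2, "AccountChanged"},
		{1, "AccountLastUsed"},
	}
	fmt.Println(len(coalesce(batch))) // 3
}
```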
## Recommended order
1. PR 1 first
- it covers the main path of this incident
2. Then PR 2
- fixes the slow admin dashboard
3. Then PR 3
- fixes the fragile admin balance endpoint
4. PR 4 last
- long-term protection
## Verification plan
Run the same set of checks before merging each PR:
1. Load-test the success path and record the number of SQL statements per request
2. Observe PostgreSQL:
- `pg_stat_activity`
- `pg_stat_statements`
- `WALWrite` / `WALSync` wait events
- WAL growth per minute
3. Observe the endpoints:
- `/api/v1/auth/refresh`
- `/api/v1/admin/dashboard/snapshot-v2`
- `/api/v1/admin/dashboard/users-trend`
- `/api/v1/admin/users/:id/balance`
4. Observe the queues:
- `ops_error_logs` queue length / dropped
- `scheduler_outbox` backlog
## Summary usable directly as the PR description
This PR reduces database write amplification on the request success path and removes several hot-path writes from `accounts.extra` + `scheduler_outbox`. It also prepares dashboard endpoints to rely on pre-aggregated data instead of scanning `usage_logs` under load. The goal is to keep admin dashboard, balance update, auth refresh, and billing-related paths stable under sustained 500+ RPS traffic.
......@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
......@@ -163,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
......@@ -935,6 +935,7 @@ type DashboardAggregationConfig struct {
// DashboardAggregationRetentionConfig pre-aggregation retention windows
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
......@@ -1301,6 +1302,7 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
......@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
......@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}
......
......@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
......@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation dedup retention",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation dedup retention smaller than usage logs",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageLogsDays = 30
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
......
......@@ -434,6 +434,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context inside a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
......@@ -445,6 +446,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
......@@ -736,6 +738,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context inside a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
......@@ -747,6 +750,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
......
......@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo
......
......@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// Usage records are submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path.
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
......@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
LongContextThreshold: 200000, // Gemini 200K threshold
LongContextMultiplier: 2.0, // double-rate billing for the excess
ForceCacheBilling: fs.ForceCacheBilling,
......
......@@ -352,6 +352,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context inside a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
......@@ -363,6 +364,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
......@@ -732,6 +734,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
......@@ -742,6 +745,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
......@@ -1238,6 +1242,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
......
......@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
// newMinimalGatewayService creates a minimal GatewayService containing only accountRepo (for testing SelectAccountForModel).
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}
......
......@@ -399,6 +399,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
......@@ -410,6 +411,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
......
......@@ -431,6 +431,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,
......
......@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
const usageLogsCleanupBatchSize = 10000
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository creates the dashboard pre-aggregation repository.
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
......@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
......@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// Aggregate on bucket boundaries; the remainder of the bucket containing end may be covered.
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
......@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageBillingDedupCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
......
......@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.Quota != 0 {
create.SetQuota(k.Quota)
}
if k.QuotaUsed != 0 {
create.SetQuotaUsed(k.QuotaUsed)
}
if k.RateLimit5h != 0 {
create.SetRateLimit5h(k.RateLimit5h)
}
if k.RateLimit1d != 0 {
create.SetRateLimit1d(k.RateLimit1d)
}
if k.RateLimit7d != 0 {
create.SetRateLimit7d(k.RateLimit7d)
}
if k.Usage5h != 0 {
create.SetUsage5h(k.Usage5h)
}
if k.Usage1d != 0 {
create.SetUsage1d(k.Usage1d)
}
if k.Usage7d != 0 {
create.SetUsage7d(k.Usage7d)
}
if k.Window5hStart != nil {
create.SetWindow5hStart(*k.Window5hStart)
}
if k.Window1dStart != nil {
create.SetWindow1dStart(*k.Window1dStart)
}
if k.Window7dStart != nil {
create.SetWindow7dStart(*k.Window7dStart)
}
if k.ExpiresAt != nil {
create.SetExpiresAt(*k.ExpiresAt)
}
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}
......
......@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
......@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
......
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageBillingRepository struct {
db *sql.DB
}
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
return &usageBillingRepository{db: sqlDB}
}
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
if cmd == nil {
return &service.UsageBillingApplyResult{}, nil
}
if r == nil || r.db == nil {
return nil, errors.New("usage billing repository db is nil")
}
cmd.Normalize()
if cmd.RequestID == "" {
return nil, service.ErrUsageBillingRequestIDRequired
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
if err != nil {
return nil, err
}
if !applied {
return &service.UsageBillingApplyResult{Applied: false}, nil
}
result := &service.UsageBillingApplyResult{Applied: true}
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tx = nil
return result, nil
}
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
var id int64
err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
var existingFingerprint string
if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
return false, err
}
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if err != nil {
return false, err
}
var archivedFingerprint string
err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
if err == nil {
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return true, nil
}
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
return err
}
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
return err
}
}
if cmd.APIKeyQuotaCost > 0 {
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
if err != nil {
return err
}
result.APIKeyQuotaExhausted = exhausted
}
if cmd.APIKeyRateLimitCost > 0 {
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
return err
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}
}
return nil
}
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
var exhausted bool
err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
if errors.Is(err, sql.ErrNoRows) {
return false, service.ErrAPIKeyNotFound
}
if err != nil {
return false, err
}
return exhausted, nil
}
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAPIKeyNotFound
}
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
} else {
if err := rows.Err(); err != nil {
return err
}
return service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
}
}
return nil
}
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-" + uuid.NewString(),
Name: "billing",
Quota: 1,
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
BalanceCost: 1.25,
APIKeyQuotaCost: 1.25,
APIKeyRateLimitCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result1)
require.True(t, result1.Applied)
require.True(t, result1.APIKeyQuotaExhausted)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result2)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
require.InDelta(t, 1.25, quotaUsed, 0.000001)
var usage5h float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
require.InDelta(t, 1.25, usage5h, 0.000001)
var status string
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
var dedupCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
require.Equal(t, 1, dedupCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
group := mustCreateGroup(t, client, &service.Group{
Name: "usage-billing-group-" + uuid.NewString(),
Platform: service.PlatformAnthropic,
SubscriptionType: service.SubscriptionTypeSubscription,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
GroupID: &group.ID,
Key: "sk-usage-billing-sub-" + uuid.NewString(),
Name: "billing-sub",
})
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: 0,
SubscriptionID: &subscription.ID,
SubscriptionCost: 2.5,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var dailyUsage float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
Name: "billing-conflict",
})
requestID := uuid.NewString()
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
})
require.NoError(t, err)
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 2.50,
})
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-account-" + uuid.NewString(),
Name: "billing-account",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-quota-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: map[string]any{
"quota_limit": 100.0,
},
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 3.5,
})
require.NoError(t, err)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
oldRequestID := "dedup-old-" + uuid.NewString()
newRequestID := "dedup-new-" + uuid.NewString()
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
newCreatedAt := time.Now().UTC().Add(-time.Hour)
_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
newRequestID, strings.Repeat("b", 64), newCreatedAt,
)
require.NoError(t, err)
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
var oldCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
require.Equal(t, 0, oldCount)
var newCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
require.Equal(t, 1, newCount)
var archivedCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
require.Equal(t, 1, archivedCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-archive-" + uuid.NewString(),
Name: "billing-archive",
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
require.NoError(t, err)
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
}
@@ -3,10 +3,14 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -15,15 +19,56 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
"bigint",
"bigint",
"text",
"text",
"bigint",
"bigint",
"integer",
"integer",
"integer",
"integer",
"integer",
"integer",
"numeric",
"numeric",
"numeric",
"numeric",
"numeric",
"numeric",
"numeric",
"numeric",
"smallint",
"smallint",
"boolean",
"boolean",
"integer",
"integer",
"text",
"text",
"integer",
"text",
"text",
"text",
"text",
"boolean",
"timestamptz",
}
// dateFormatWhitelist maps the granularity parameter to a PostgreSQL TO_CHAR format string, preventing external input from being concatenated into SQL directly
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
@@ -43,15 +88,89 @@ func safeDateFormat(granularity string) string {
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
db *sql.DB
createBatchOnce sync.Once
createBatchCh chan usageLogCreateRequest
bestEffortBatchOnce sync.Once
bestEffortBatchCh chan usageLogBestEffortRequest
bestEffortRecent *gocache.Cache
}
const (
usageLogCreateBatchMaxSize = 64
usageLogCreateBatchWindow = 3 * time.Millisecond
usageLogCreateBatchQueueCap = 4096
usageLogCreateCancelWait = 2 * time.Second
usageLogBestEffortBatchMaxSize = 256
usageLogBestEffortBatchWindow = 20 * time.Millisecond
usageLogBestEffortBatchQueueCap = 32768
usageLogBestEffortRecentTTL = 30 * time.Second
)
type usageLogCreateRequest struct {
log *service.UsageLog
prepared usageLogInsertPrepared
shared *usageLogCreateShared
resultCh chan usageLogCreateResult
}
type usageLogCreateResult struct {
inserted bool
err error
}
type usageLogBestEffortRequest struct {
prepared usageLogInsertPrepared
apiKeyID int64
resultCh chan error
}
type usageLogInsertPrepared struct {
createdAt time.Time
requestID string
rateMultiplier float64
requestType int16
args []any
}
type usageLogBatchState struct {
ID int64
CreatedAt time.Time
}
type usageLogBatchRow struct {
RequestID string `json:"request_id"`
APIKeyID int64 `json:"api_key_id"`
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
Inserted bool `json:"inserted"`
}
type usageLogCreateShared struct {
state atomic.Int32
}
const (
usageLogCreateStateQueued int32 = iota
usageLogCreateStateProcessing
usageLogCreateStateCompleted
usageLogCreateStateCanceled
)
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
return newUsageLogRepositoryWithSQL(client, sqlDB)
}
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// Use scanSingleRow instead of QueryRowContext so that ent.Tx remains usable as a sqlExecutor.
repo := &usageLogRepository{client: client, sql: sqlq}
if db, ok := sqlq.(*sql.DB); ok {
repo.db = db
}
repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
return repo
}
// getPerformanceStats returns RPM and TPM (averages over the last 5 minutes, optionally filtered by user)
@@ -82,24 +201,72 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return false, nil
}
// Inside a transaction context, execute raw SQL through the tx-bound ExecQuerier so it stays in the same transaction as the other updates.
// Without a transaction, fall back to the default *sql.DB executor.
if tx := dbent.TxFromContext(ctx); tx != nil {
return r.createSingle(ctx, tx.Client(), log)
}
requestID := strings.TrimSpace(log.RequestID)
if requestID == "" {
return r.createSingle(ctx, r.sql, log)
}
log.RequestID = requestID
return r.createBatched(ctx, log)
}
func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
if log == nil {
return nil
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
if tx := dbent.TxFromContext(ctx); tx != nil {
_, err := r.createSingle(ctx, tx.Client(), log)
return err
}
if r.db == nil {
_, err := r.createSingle(ctx, r.sql, log)
return err
}
r.ensureBestEffortBatcher()
if r.bestEffortBatchCh == nil {
_, err := r.createSingle(ctx, r.sql, log)
return err
}
req := usageLogBestEffortRequest{
prepared: prepareUsageLogInsert(log),
apiKeyID: log.APIKeyID,
resultCh: make(chan error, 1),
}
if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
if _, exists := r.bestEffortRecent.Get(key); exists {
return nil
}
}
select {
case r.bestEffortBatchCh <- req:
case <-ctx.Done():
return service.MarkUsageLogCreateDropped(ctx.Err())
default:
return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full"))
}
select {
case err := <-req.resultCh:
return err
case <-ctx.Done():
return service.MarkUsageLogCreateDropped(ctx.Err())
}
}
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
prepared := prepareUsageLogInsert(log)
if sqlq == nil {
sqlq = r.sql
}
if ctx != nil && ctx.Err() != nil {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
query := `
INSERT INTO usage_logs (
@@ -151,76 +318,863 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
RETURNING id, created_at
`
if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = prepared.rateMultiplier
return false, nil
} else {
return false, err
}
}
log.RateMultiplier = prepared.rateMultiplier
return true, nil
}
func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) {
if r.db == nil {
return r.createSingle(ctx, r.sql, log)
}
r.ensureCreateBatcher()
if r.createBatchCh == nil {
return r.createSingle(ctx, r.sql, log)
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
select {
case r.createBatchCh <- req:
case <-ctx.Done():
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
default:
return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full"))
}
select {
case res := <-req.resultCh:
return res.inserted, res.err
case <-ctx.Done():
if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
}
timer := time.NewTimer(usageLogCreateCancelWait)
defer timer.Stop()
select {
case res := <-req.resultCh:
return res.inserted, res.err
case <-timer.C:
return false, ctx.Err()
}
}
}
func (r *usageLogRepository) ensureCreateBatcher() {
if r == nil || r.db == nil || r.createBatchCh != nil {
return
}
r.createBatchOnce.Do(func() {
r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap)
go r.runCreateBatcher(r.db)
})
}
func (r *usageLogRepository) ensureBestEffortBatcher() {
if r == nil || r.db == nil || r.bestEffortBatchCh != nil {
return
}
r.bestEffortBatchOnce.Do(func() {
r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
go r.runBestEffortBatcher(r.db)
})
}
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
for {
first, ok := <-r.createBatchCh
if !ok {
return
}
batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize)
batch = append(batch, first)
timer := time.NewTimer(usageLogCreateBatchWindow)
batchLoop:
for len(batch) < usageLogCreateBatchMaxSize {
select {
case req, ok := <-r.createBatchCh:
if !ok {
break batchLoop
}
batch = append(batch, req)
case <-timer.C:
break batchLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
r.flushCreateBatch(db, batch)
}
}
func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
for {
first, ok := <-r.bestEffortBatchCh
if !ok {
return
}
batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
batch = append(batch, first)
timer := time.NewTimer(usageLogBestEffortBatchWindow)
bestEffortLoop:
for len(batch) < usageLogBestEffortBatchMaxSize {
select {
case req, ok := <-r.bestEffortBatchCh:
if !ok {
break bestEffortLoop
}
batch = append(batch, req)
case <-timer.C:
break bestEffortLoop
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
r.flushBestEffortBatch(db, batch)
}
}
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
if len(batch) == 0 {
return
}
uniqueOrder := make([]string, 0, len(batch))
preparedByKey := make(map[string]usageLogInsertPrepared, len(batch))
requestsByKey := make(map[string][]usageLogCreateRequest, len(batch))
fallback := make([]usageLogCreateRequest, 0)
for _, req := range batch {
if req.log == nil {
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
continue
}
if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
if req.shared.state.Load() == usageLogCreateStateCanceled {
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: service.MarkUsageLogCreateNotPersisted(context.Canceled),
})
continue
}
}
prepared := req.prepared
if prepared.requestID == "" {
fallback = append(fallback, req)
continue
}
key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID)
if _, exists := requestsByKey[key]; !exists {
uniqueOrder = append(uniqueOrder, key)
preparedByKey[key] = prepared
}
requestsByKey[key] = append(requestsByKey[key], req)
}
if len(uniqueOrder) > 0 {
insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
if err != nil {
if safeFallback {
for _, key := range uniqueOrder {
fallback = append(fallback, requestsByKey[key]...)
}
} else {
for _, key := range uniqueOrder {
reqs := requestsByKey[key]
state, hasState := stateMap[key]
inserted := insertedMap[key]
for idx, req := range reqs {
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
if hasState {
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
}
switch {
case inserted && idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
case inserted:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case hasState:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
case idx == 0:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
default:
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
}
}
}
}
} else {
for _, key := range uniqueOrder {
reqs := requestsByKey[key]
state, ok := stateMap[key]
if !ok {
for _, req := range reqs {
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: false,
err: fmt.Errorf("usage log batch state missing for key=%s", key),
})
}
continue
}
for idx, req := range reqs {
req.log.ID = state.ID
req.log.CreatedAt = state.CreatedAt
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
completeUsageLogCreateRequest(req, usageLogCreateResult{
inserted: idx == 0 && insertedMap[key],
err: nil,
})
}
}
}
}
if len(fallback) == 0 {
return
}
for _, req := range fallback {
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
inserted, err := r.createSingle(fallbackCtx, db, req.log)
cancel()
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
}
}
func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
if len(batch) == 0 {
return
}
type bestEffortGroup struct {
prepared usageLogInsertPrepared
apiKeyID int64
key string
reqs []usageLogBestEffortRequest
}
groupsByKey := make(map[string]*bestEffortGroup, len(batch))
groupOrder := make([]*bestEffortGroup, 0, len(batch))
preparedList := make([]usageLogInsertPrepared, 0, len(batch))
for idx, req := range batch {
prepared := req.prepared
key := fmt.Sprintf("__best_effort_%d", idx)
if prepared.requestID != "" {
key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
}
group, exists := groupsByKey[key]
if !exists {
group = &bestEffortGroup{
prepared: prepared,
apiKeyID: req.apiKeyID,
key: key,
}
groupsByKey[key] = group
groupOrder = append(groupOrder, group)
preparedList = append(preparedList, prepared)
}
group.reqs = append(group.reqs, req)
}
if len(preparedList) == 0 {
for _, req := range batch {
sendUsageLogBestEffortResult(req.resultCh, nil)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBestEffortInsertQuery(preparedList)
if _, err := db.ExecContext(ctx, query, args...); err != nil {
logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
for _, group := range groupOrder {
singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
if singleErr != nil {
logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
r.bestEffortRecent.SetDefault(group.key, struct{}{})
}
for _, req := range group.reqs {
sendUsageLogBestEffortResult(req.resultCh, singleErr)
}
}
return
}
for _, group := range groupOrder {
if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
r.bestEffortRecent.SetDefault(group.key, struct{}{})
}
for _, req := range group.reqs {
sendUsageLogBestEffortResult(req.resultCh, nil)
}
}
}
func sendUsageLogBestEffortResult(ch chan error, err error) {
if ch == nil {
return
}
select {
case ch <- err:
default:
}
}
func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
if req.shared != nil {
req.shared.state.Store(usageLogCreateStateCompleted)
}
sendUsageLogCreateResult(req.resultCh, res)
}
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
if len(keys) == 0 {
return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
var payload []byte
if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
return nil, nil, true, err
}
var rows []usageLogBatchRow
if err := json.Unmarshal(payload, &rows); err != nil {
return nil, nil, false, err
}
insertedMap := make(map[string]bool, len(keys))
stateMap := make(map[string]usageLogBatchState, len(keys))
for _, row := range rows {
key := usageLogBatchKey(row.RequestID, row.APIKeyID)
insertedMap[key] = row.Inserted
stateMap[key] = usageLogBatchState{
ID: row.ID,
CreatedAt: row.CreatedAt,
}
}
if len(stateMap) != len(keys) {
return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
}
return insertedMap, stateMap, false, nil
}
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
WITH input (
input_idx,
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*37)
argPos := 1
for idx, key := range keys {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
args = append(args, idx)
argPos++
prepared := preparedByKey[key]
for i := 0; i < len(prepared.args); i++ {
_, _ = query.WriteString(",")
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
if i < len(usageLogInsertArgTypes) {
_, _ = query.WriteString("::")
_, _ = query.WriteString(usageLogInsertArgTypes[i])
}
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
),
inserted AS (
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING request_id, api_key_id, id, created_at
),
resolved AS (
SELECT
input.input_idx,
input.request_id,
input.api_key_id,
COALESCE(inserted.id, existing.id) AS id,
COALESCE(inserted.created_at, existing.created_at) AS created_at,
(inserted.id IS NOT NULL) AS inserted
FROM input
LEFT JOIN inserted
ON inserted.request_id = input.request_id
AND inserted.api_key_id = input.api_key_id
LEFT JOIN usage_logs existing
ON existing.request_id = input.request_id
AND existing.api_key_id = input.api_key_id
)
SELECT COALESCE(
json_agg(
json_build_object(
'request_id', resolved.request_id,
'api_key_id', resolved.api_key_id,
'id', resolved.id,
'created_at', resolved.created_at,
'inserted', resolved.inserted
)
ORDER BY resolved.input_idx
),
'[]'::json
)
FROM resolved
`)
return query.String(), args
}
func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
var query strings.Builder
_, _ = query.WriteString(`
WITH input (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("(")
for i := 0; i < len(prepared.args); i++ {
if i > 0 {
_, _ = query.WriteString(",")
}
_, _ = query.WriteString("$")
_, _ = query.WriteString(strconv.Itoa(argPos))
if i < len(usageLogInsertArgTypes) {
_, _ = query.WriteString("::")
_, _ = query.WriteString(usageLogInsertArgTypes[i])
}
argPos++
}
_, _ = query.WriteString(")")
args = append(args, prepared.args...)
}
_, _ = query.WriteString(`
)
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
)
SELECT
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
`)
return query.String(), args
}
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
_, err := sqlq.ExecContext(ctx, `
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
media_type,
service_tier,
reasoning_effort,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
return err
}
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
createdAt := log.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now()
}
requestID := strings.TrimSpace(log.RequestID)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
groupID := nullInt64(log.GroupID)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any
if requestID != "" {
requestIDArg = requestID
}
return usageLogInsertPrepared{
createdAt: createdAt,
requestID: requestID,
rateMultiplier: rateMultiplier,
requestType: requestType,
args: []any{
log.UserID,
log.APIKeyID,
log.AccountID,
requestIDArg,
log.Model,
groupID,
subscriptionID,
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
requestType,
log.Stream,
log.OpenAIWSMode,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
mediaType,
serviceTier,
reasoningEffort,
log.CacheTTLOverridden,
createdAt,
},
}
}
func usageLogBatchKey(requestID string, apiKeyID int64) string {
return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10)
}
func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateResult) {
if ch == nil {
return
}
select {
case ch <- res:
default:
}
}
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
requestID = strings.TrimSpace(requestID)
if requestID == "" || r == nil || r.bestEffortRecent == nil {
return "", false
}
return usageLogBatchKey(requestID, apiKeyID), true
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)
......
......@@ -4,6 +4,8 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
......@@ -14,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
......@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
requestID := uuid.NewString()
const total = 8
batch := make([]usageLogCreateRequest, 0, total)
logs := make([]*service.UsageLog, 0, total)
for i := 0; i < total; i++ {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
logs = append(logs, log)
batch = append(batch, usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
resultCh: make(chan usageLogCreateResult, 1),
})
}
repo.flushCreateBatch(integrationDB, batch)
insertedCount := 0
var firstID int64
for idx, req := range batch {
res := <-req.resultCh
require.NoError(t, res.err)
if res.inserted {
insertedCount++
}
require.NotZero(t, logs[idx].ID)
if idx == 0 {
firstID = logs[idx].ID
} else {
require.Equal(t, firstID, logs[idx].ID)
}
}
require.Equal(t, 1, insertedCount)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
require.NoError(t, repo.CreateBestEffort(ctx, log1))
require.NoError(t, repo.CreateBestEffort(ctx, log2))
require.Eventually(t, func() bool {
var count int
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
return err == nil && count == 1
}, 3*time.Second, 20*time.Millisecond)
}
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
err := repo.CreateBestEffort(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.Error(t, err)
require.True(t, service.IsUsageLogCreateDropped(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
repo.createBatchCh <- usageLogCreateRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
_, err := repo.createBatched(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
errCh <- err
}()
req := <-repo.createBatchCh
require.NotNil(t, req.shared)
cancel()
err := <-errCh
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
req.shared.state.Store(usageLogCreateStateCanceled)
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
res := <-req.resultCh
require.False(t, res.inserted)
require.Error(t, res.err)
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
......
......@@ -3,8 +3,11 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
......@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-batch-no-update",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 5,
TotalCost: 1.2,
ActualCost: 1.2,
CreatedAt: time.Now().UTC(),
}
prepared := prepareUsageLogInsert(log)
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
})
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
}
......@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
......