Unverified Commit 45bd9ac7 authored by IanShaw's avatar IanShaw Committed by GitHub
Browse files

运维监控系统安全加固和功能优化 (#21)

* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
parent 7fdc2b2d
...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/keys (paginated)", name: "GET /api/v1/keys (paginated)",
setup: func(t *testing.T, deps *contractDeps) { setup: func(t *testing.T, deps *contractDeps) {
t.Helper() t.Helper()
deps.apiKeyRepo.MustSeed(&service.ApiKey{ deps.apiKeyRepo.MustSeed(&service.APIKey{
ID: 100, ID: 100,
UserID: 1, UserID: 1,
Key: "sk_custom_1234567890", Key: "sk_custom_1234567890",
...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -135,7 +135,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 10, InputTokens: 10,
...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 2, ID: 2,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
Model: "claude-3", Model: "claude-3",
InputTokens: 5, InputTokens: 5,
...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -188,7 +188,7 @@ func TestAPIContracts(t *testing.T) {
{ {
ID: 1, ID: 1,
UserID: 1, UserID: 1,
ApiKeyID: 100, APIKeyID: 100,
AccountID: 200, AccountID: 200,
RequestID: "req_123", RequestID: "req_123",
Model: "claude-3", Model: "claude-3",
...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) { ...@@ -259,13 +259,13 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false", service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeySmtpHost: "smtp.example.com", service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySmtpPort: "587", service.SettingKeySMTPPort: "587",
service.SettingKeySmtpUsername: "user", service.SettingKeySMTPUsername: "user",
service.SettingKeySmtpPassword: "secret", service.SettingKeySMTPPassword: "secret",
service.SettingKeySmtpFrom: "no-reply@example.com", service.SettingKeySMTPFrom: "no-reply@example.com",
service.SettingKeySmtpFromName: "Sub2API", service.SettingKeySMTPFromName: "Sub2API",
service.SettingKeySmtpUseTLS: "true", service.SettingKeySMTPUseTLS: "true",
service.SettingKeyTurnstileEnabled: "true", service.SettingKeyTurnstileEnabled: "true",
service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSiteKey: "site-key",
...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) { ...@@ -274,9 +274,9 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeySiteName: "Sub2API", service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "", service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle", service.SettingKeySiteSubtitle: "Subtitle",
service.SettingKeyApiBaseUrl: "https://api.example.com", service.SettingKeyAPIBaseURL: "https://api.example.com",
service.SettingKeyContactInfo: "support", service.SettingKeyContactInfo: "support",
service.SettingKeyDocUrl: "https://docs.example.com", service.SettingKeyDocURL: "https://docs.example.com",
service.SettingKeyDefaultConcurrency: "5", service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25", service.SettingKeyDefaultBalance: "1.25",
...@@ -331,7 +331,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -331,7 +331,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct { type contractDeps struct {
now time.Time now time.Time
router http.Handler router http.Handler
apiKeyRepo *stubApiKeyRepo apiKeyRepo *stubAPIKeyRepo
usageRepo *stubUsageLogRepo usageRepo *stubUsageLogRepo
settingRepo *stubSettingRepo settingRepo *stubSettingRepo
} }
...@@ -359,20 +359,20 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -359,20 +359,20 @@ func newContractDeps(t *testing.T) *contractDeps {
}, },
} }
apiKeyRepo := newStubApiKeyRepo(now) apiKeyRepo := newStubAPIKeyRepo(now)
apiKeyCache := stubApiKeyCache{} apiKeyCache := stubAPIKeyCache{}
groupRepo := stubGroupRepo{} groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{} userSubRepo := stubUserSubscriptionRepo{}
cfg := &config.Config{ cfg := &config.Config{
Default: config.DefaultConfig{ Default: config.DefaultConfig{
ApiKeyPrefix: "sk-", APIKeyPrefix: "sk-",
}, },
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo) userService := service.NewUserService(userRepo)
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo) usageService := service.NewUsageService(usageRepo, userRepo)
...@@ -525,25 +525,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID ...@@ -525,25 +525,25 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
type stubApiKeyCache struct{} type stubAPIKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (stubAPIKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil return 0, nil
} }
func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { func (stubAPIKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil return nil
} }
func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { func (stubAPIKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil return nil
} }
func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { func (stubAPIKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil return nil
} }
func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { func (stubAPIKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil return nil
} }
...@@ -660,24 +660,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i ...@@ -660,24 +660,24 @@ func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (i
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
type stubApiKeyRepo struct { type stubAPIKeyRepo struct {
now time.Time now time.Time
nextID int64 nextID int64
byID map[int64]*service.ApiKey byID map[int64]*service.APIKey
byKey map[string]*service.ApiKey byKey map[string]*service.APIKey
} }
func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo { func newStubAPIKeyRepo(now time.Time) *stubAPIKeyRepo {
return &stubApiKeyRepo{ return &stubAPIKeyRepo{
now: now, now: now,
nextID: 100, nextID: 100,
byID: make(map[int64]*service.ApiKey), byID: make(map[int64]*service.APIKey),
byKey: make(map[string]*service.ApiKey), byKey: make(map[string]*service.APIKey),
} }
} }
func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { func (r *stubAPIKeyRepo) MustSeed(key *service.APIKey) {
if key == nil { if key == nil {
return return
} }
...@@ -686,7 +686,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { ...@@ -686,7 +686,7 @@ func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) {
r.byKey[clone.Key] = &clone r.byKey[clone.Key] = &clone
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
...@@ -706,38 +706,38 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error ...@@ -706,38 +706,38 @@ func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error
return nil return nil
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *key clone := *key
return &clone, nil return &clone, nil
} }
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return 0, service.ErrApiKeyNotFound return 0, service.ErrAPIKeyNotFound
} }
return key.UserID, nil return key.UserID, nil
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
found, ok := r.byKey[key] found, ok := r.byKey[key]
if !ok { if !ok {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *found clone := *found
return &clone, nil return &clone, nil
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil { if key == nil {
return errors.New("nil key") return errors.New("nil key")
} }
if _, ok := r.byID[key.ID]; !ok { if _, ok := r.byID[key.ID]; !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
if key.UpdatedAt.IsZero() { if key.UpdatedAt.IsZero() {
key.UpdatedAt = r.now key.UpdatedAt = r.now
...@@ -748,17 +748,17 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error ...@@ -748,17 +748,17 @@ func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error
return nil return nil
} }
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
key, ok := r.byID[id] key, ok := r.byID[id]
if !ok { if !ok {
return service.ErrApiKeyNotFound return service.ErrAPIKeyNotFound
} }
delete(r.byID, id) delete(r.byID, id)
delete(r.byKey, key.Key) delete(r.byKey, key.Key)
return nil return nil
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
ids := make([]int64, 0, len(r.byID)) ids := make([]int64, 0, len(r.byID))
for id := range r.byID { for id := range r.byID {
if r.byID[id].UserID == userID { if r.byID[id].UserID == userID {
...@@ -776,7 +776,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params ...@@ -776,7 +776,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
end = len(ids) end = len(ids)
} }
out := make([]service.ApiKey, 0, end-start) out := make([]service.APIKey, 0, end-start)
for _, id := range ids[start:end] { for _, id := range ids[start:end] {
clone := *r.byID[id] clone := *r.byID[id]
out = append(out, clone) out = append(out, clone)
...@@ -796,7 +796,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params ...@@ -796,7 +796,7 @@ func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params
}, nil }, nil
} }
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
...@@ -815,7 +815,7 @@ func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK ...@@ -815,7 +815,7 @@ func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiK
return out, nil return out, nil
} }
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
for _, key := range r.byID { for _, key := range r.byID {
if key.UserID == userID { if key.UserID == userID {
...@@ -825,24 +825,24 @@ func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64 ...@@ -825,24 +825,24 @@ func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64
return count, nil return count, nil
} }
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
_, ok := r.byKey[key] _, ok := r.byKey[key]
return ok, nil return ok, nil
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
...@@ -877,7 +877,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params ...@@ -877,7 +877,7 @@ func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params
return out, paginationResult(total, params), nil return out, paginationResult(total, params), nil
} }
func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -890,7 +890,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in ...@@ -890,7 +890,7 @@ func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID in
return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil
} }
func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
...@@ -922,7 +922,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -922,7 +922,7 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -975,7 +975,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in ...@@ -975,7 +975,7 @@ func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID in
}, nil }, nil
} }
func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { func (r *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -995,7 +995,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -995,7 +995,7 @@ func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs [
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -1017,8 +1017,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio ...@@ -1017,8 +1017,8 @@ func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params paginatio
// Apply filters // Apply filters
var filtered []service.UsageLog var filtered []service.UsageLog
for _, log := range logs { for _, log := range logs {
// Apply ApiKeyID filter // Apply APIKeyID filter
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID { if filters.APIKeyID > 0 && log.APIKeyID != filters.APIKeyID {
continue continue
} }
// Apply Model filter // Apply Model filter
...@@ -1151,8 +1151,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati ...@@ -1151,8 +1151,8 @@ func paginationResult(total int64, params pagination.PaginationParams) *paginati
// Ensure compile-time interface compliance. // Ensure compile-time interface compliance.
var ( var (
_ service.UserRepository = (*stubUserRepo)(nil) _ service.UserRepository = (*stubUserRepo)(nil)
_ service.ApiKeyRepository = (*stubApiKeyRepo)(nil) _ service.APIKeyRepository = (*stubAPIKeyRepo)(nil)
_ service.ApiKeyCache = (*stubApiKeyCache)(nil) _ service.APIKeyCache = (*stubAPIKeyCache)(nil)
_ service.GroupRepository = (*stubGroupRepo)(nil) _ service.GroupRepository = (*stubGroupRepo)(nil)
_ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil)
_ service.UsageLogRepository = (*stubUsageLogRepo)(nil) _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
......
// Package server provides HTTP server setup and routing configuration.
package server package server
import ( import (
...@@ -25,8 +26,8 @@ func ProvideRouter( ...@@ -25,8 +26,8 @@ func ProvideRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
) *gin.Engine { ) *gin.Engine {
if cfg.Server.Mode == "release" { if cfg.Server.Mode == "release" {
......
...@@ -32,7 +32,7 @@ func adminAuth( ...@@ -32,7 +32,7 @@ func adminAuth(
// 检查 x-api-key header(Admin API Key 认证) // 检查 x-api-key header(Admin API Key 认证)
apiKey := c.GetHeader("x-api-key") apiKey := c.GetHeader("x-api-key")
if apiKey != "" { if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) { if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return return
} }
c.Next() c.Next()
...@@ -52,19 +52,48 @@ func adminAuth( ...@@ -52,19 +52,48 @@ func adminAuth(
} }
} }
// WebSocket 请求无法设置自定义 header,允许在 query 中携带凭证
if isWebSocketRequest(c) {
if token := strings.TrimSpace(c.Query("token")); token != "" {
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
return
}
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息 // 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required") AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
} }
} }
// validateAdminApiKey 验证管理员 API Key func isWebSocketRequest(c *gin.Context) bool {
func validateAdminApiKey( if c == nil || c.Request == nil {
return false
}
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
return true
}
conn := strings.ToLower(c.GetHeader("Connection"))
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
}
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context, c *gin.Context,
key string, key string,
settingService *service.SettingService, settingService *service.SettingService,
userService *service.UserService, userService *service.UserService,
) bool { ) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context()) storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
if err != nil { if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error") AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false return false
......
...@@ -11,13 +11,13 @@ import ( ...@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware { func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService))
} }
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme) // 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
...@@ -53,6 +53,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -53,6 +53,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 如果所有header都没有API key // 如果所有header都没有API key
if apiKeyString == "" { if apiKeyString == "" {
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter") AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return return
} }
...@@ -60,35 +61,40 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -60,35 +61,40 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key") AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return return
} }
// 检查API key是否激活 // 检查API key是否激活
if !apiKey.IsActive() { if !apiKey.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return return
} }
// 检查关联的用户 // 检查关联的用户
if apiKey.User == nil { if apiKey.User == nil {
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return return
} }
// 检查用户状态 // 检查用户状态
if !apiKey.User.IsActive() { if !apiKey.User.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active") AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return return
} }
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -109,12 +115,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -109,12 +115,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
apiKey.Group.ID, apiKey.Group.ID,
) )
if err != nil { if err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return return
} }
// 验证订阅状态(是否过期、暂停等) // 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return return
} }
...@@ -131,6 +139,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -131,6 +139,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 预检查用量限制(使用0作为额外费用进行预检查) // 预检查用量限制(使用0作为额外费用进行预检查)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return return
} }
...@@ -140,13 +149,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -140,13 +149,14 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} else { } else {
// 余额模式:检查用户余额 // 余额模式:检查用户余额
if apiKey.User.Balance <= 0 { if apiKey.User.Balance <= 0 {
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return return
} }
} }
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -157,13 +167,66 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti ...@@ -157,13 +167,66 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
} }
} }
// GetApiKeyFromContext 从上下文中获取API key func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) {
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) { if opsService == nil || c == nil {
value, exists := c.Get(string(ContextKeyApiKey)) return
}
errType := "authentication_error"
phase := "auth"
severity := "P3"
switch status {
case 403:
errType = "billing_error"
phase = "billing"
case 429:
errType = "rate_limit_error"
phase = "billing"
severity = "P2"
case 500:
errType = "api_error"
phase = "internal"
severity = "P1"
}
logEntry := &service.OpsErrorLog{
Phase: phase,
Type: errType,
Severity: severity,
StatusCode: status,
Message: message,
ClientIP: c.ClientIP(),
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
}
if apiKey != nil {
logEntry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
logEntry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
logEntry.GroupID = apiKey.GroupID
}
if apiKey.Group != nil {
logEntry.Platform = apiKey.Group.Platform
}
}
enqueueOpsAuthErrorLog(opsService, logEntry)
}
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*service.ApiKey) apiKey, ok := value.(*service.APIKey)
return apiKey, ok return apiKey, ok
} }
......
...@@ -11,16 +11,16 @@ import ( ...@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. // APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
} }
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
// //
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c) apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" { if apiKeyString == "" {
...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -30,7 +30,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) { if errors.Is(err, service.ErrAPIKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key") abortWithGoogleError(c, 401, "Invalid API key")
return return
} }
...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -53,7 +53,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
// 简易模式:跳过余额和订阅检查 // 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple { if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs ...@@ -92,7 +92,7 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
} }
} }
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{ c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID, UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
......
...@@ -16,53 +16,53 @@ import ( ...@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type fakeApiKeyRepo struct { type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil { if f.getByKey == nil {
return nil, errors.New("unexpected call") return nil, errors.New("unexpected call")
} }
return f.getByKey(ctx, key) return f.getByKey(ctx, key)
} }
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error { func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
...@@ -74,8 +74,8 @@ type googleErrorResponse struct { ...@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"` } `json:"error"`
} }
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService { func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewApiKeyService( return service.NewAPIKeyService(
repo, repo,
nil, // userRepo (unused in GetByKey) nil, // userRepo (unused in GetByKey)
nil, // groupRepo nil, // groupRepo
...@@ -85,16 +85,16 @@ func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService ...@@ -85,16 +85,16 @@ func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService
) )
} }
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("should not be called") return nil, errors.New("should not be called")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -109,16 +109,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) { ...@@ -109,16 +109,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -134,16 +134,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) { ...@@ -134,16 +134,16 @@ func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, errors.New("db down") return nil, errors.New("db down")
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -159,13 +159,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) { ...@@ -159,13 +159,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
require.Equal(t, "INTERNAL", resp.Error.Status) require.Equal(t, "INTERNAL", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusDisabled, Status: service.StatusDisabled,
...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
...@@ -192,13 +192,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) { ...@@ -192,13 +192,13 @@ func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status) require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
} }
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
r := gin.New() r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.ApiKey{ return &service.APIKey{
ID: 1, ID: 1,
Key: key, Key: key,
Status: service.StatusActive, Status: service.StatusActive,
...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { ...@@ -210,7 +210,7 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil }, nil
}, },
}) })
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{})) r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
......
...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10, Balance: 10,
Concurrency: 3, Concurrency: 3,
} }
apiKey := &service.ApiKey{ apiKey := &service.APIKey{
ID: 100, ID: 100,
UserID: user.ID, UserID: user.ID,
Key: "test-key", Key: "test-key",
...@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
} }
apiKey.GroupID = &group.ID apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{ apiKeyRepo := &stubAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key { if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound return nil, service.ErrAPIKeyNotFound
} }
clone := *apiKey clone := *apiKey
return &clone, nil return &clone, nil
...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple} cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard} cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now() now := time.Now()
sub := &service.UserSubscription{ sub := &service.UserSubscription{
...@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { ...@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}) })
} }
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil)))
router.GET("/t", func(c *gin.Context) { router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
}) })
return router return router
} }
type stubApiKeyRepo struct { type stubAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error) getByKey func(ctx context.Context, key string) (*service.APIKey, error)
} }
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if r.getByKey != nil { if r.getByKey != nil {
return r.getByKey(ctx, key) return r.getByKey(ctx, key)
} }
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented") return nil, nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
......
...@@ -2,11 +2,14 @@ package middleware ...@@ -2,11 +2,14 @@ package middleware
import ( import (
"log" "log"
"regexp"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
// Logger 请求日志中间件 // Logger 请求日志中间件
func Logger() gin.HandlerFunc { func Logger() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
...@@ -26,7 +29,7 @@ func Logger() gin.HandlerFunc { ...@@ -26,7 +29,7 @@ func Logger() gin.HandlerFunc {
method := c.Request.Method method := c.Request.Method
// 请求路径 // 请求路径
path := c.Request.URL.Path path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***")
// 状态码 // 状态码
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
......
// Package middleware provides HTTP middleware components for authentication,
// authorization, logging, error recovery, and request processing.
package middleware package middleware
import ( import (
...@@ -15,8 +17,8 @@ const ( ...@@ -15,8 +17,8 @@ const (
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色(string) // ContextKeyUserRole 当前用户角色(string)
ContextKeyUserRole ContextKey = "user_role" ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyAPIKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyAPIKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription" ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由) // ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
......
package middleware
import (
"context"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
opsAuthErrorLogWorkerCount = 10
opsAuthErrorLogQueueSize = 256
opsAuthErrorLogTimeout = 2 * time.Second
)
type opsAuthErrorLogJob struct {
ops *service.OpsService
entry *service.OpsErrorLog
}
var (
opsAuthErrorLogOnce sync.Once
opsAuthErrorLogQueue chan opsAuthErrorLogJob
)
func startOpsAuthErrorLogWorkers() {
opsAuthErrorLogQueue = make(chan opsAuthErrorLogJob, opsAuthErrorLogQueueSize)
for i := 0; i < opsAuthErrorLogWorkerCount; i++ {
go func() {
for job := range opsAuthErrorLogQueue {
if job.ops == nil || job.entry == nil {
continue
}
ctx, cancel := context.WithTimeout(context.Background(), opsAuthErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry)
cancel()
}
}()
}
}
func enqueueOpsAuthErrorLog(ops *service.OpsService, entry *service.OpsErrorLog) {
if ops == nil || entry == nil {
return
}
opsAuthErrorLogOnce.Do(startOpsAuthErrorLogWorkers)
select {
case opsAuthErrorLogQueue <- opsAuthErrorLogJob{ops: ops, entry: entry}:
default:
// Queue is full; drop to avoid blocking request handling.
}
}
...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc ...@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型 // AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型 // APIKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc type APIKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入 // ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware, NewJWTAuthMiddleware,
NewAdminAuthMiddleware, NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware, NewAPIKeyAuthMiddleware,
) )
...@@ -17,8 +17,8 @@ func SetupRouter( ...@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers, handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) *gin.Engine { ) *gin.Engine {
...@@ -43,8 +43,8 @@ func registerRoutes( ...@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware, jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware, adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
......
...@@ -19,6 +19,9 @@ func RegisterAdminRoutes( ...@@ -19,6 +19,9 @@ func RegisterAdminRoutes(
// 仪表盘 // 仪表盘
registerDashboardRoutes(admin, h) registerDashboardRoutes(admin, h)
// 运维监控
registerOpsRoutes(admin, h)
// 用户管理 // 用户管理
registerUserManagementRoutes(admin, h) registerUserManagementRoutes(admin, h)
...@@ -67,10 +70,35 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -67,10 +70,35 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
}
}
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops := admin.Group("/ops")
{
ops.GET("/metrics", h.Admin.Ops.GetMetrics)
ops.GET("/metrics/history", h.Admin.Ops.ListMetricsHistory)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/error-logs", h.Admin.Ops.ListErrorLogs)
// Dashboard routes
dashboard := ops.Group("/dashboard")
{
dashboard.GET("/overview", h.Admin.Ops.GetDashboardOverview)
dashboard.GET("/providers", h.Admin.Ops.GetProviderHealth)
dashboard.GET("/latency-histogram", h.Admin.Ops.GetLatencyHistogram)
dashboard.GET("/errors/distribution", h.Admin.Ops.GetErrorDistribution)
}
// WebSocket routes
ws := ops.Group("/ws")
{
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
} }
} }
...@@ -203,12 +231,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -203,12 +231,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{ {
adminSettings.GET("", h.Admin.Setting.GetSettings) adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings) adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理 // Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey) adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
} }
} }
...@@ -248,7 +276,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -248,7 +276,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List) usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
} }
} }
......
// Package routes 提供 HTTP 路由注册和处理函数
package routes package routes
import ( import (
......
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes( func RegisterGatewayRoutes(
r *gin.Engine, r *gin.Engine,
h *handler.Handlers, h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.ApiKeyService, apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes( ...@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit) gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
...@@ -62,7 +62,7 @@ func RegisterGatewayRoutes( ...@@ -62,7 +62,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
......
...@@ -50,7 +50,7 @@ func RegisterUserRoutes( ...@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats) usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend) usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels) usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
} }
// 卡密兑换 // 卡密兑换
......
...@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string { ...@@ -206,7 +206,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey { if a.Type != AccountTypeAPIKey {
return "" return ""
} }
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
...@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string { ...@@ -229,7 +229,7 @@ func (a *Account) GetExtraString(key string) string {
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false
} }
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
...@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool { ...@@ -300,15 +300,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth return a.IsOpenAI() && a.Type == AccountTypeOAuth
} }
func (a *Account) IsOpenAIApiKey() bool { func (a *Account) IsOpenAIAPIKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey return a.IsOpenAI() && a.Type == AccountTypeAPIKey
} }
func (a *Account) GetOpenAIBaseURL() string { func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() { if !a.IsOpenAI() {
return "" return ""
} }
if a.Type == AccountTypeApiKey { if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url") baseURL := a.GetCredential("base_url")
if baseURL != "" { if baseURL != "" {
return baseURL return baseURL
...@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string { ...@@ -338,8 +338,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token") return a.GetCredential("id_token")
} }
func (a *Account) GetOpenAIApiKey() string { func (a *Account) GetOpenAIAPIKey() string {
if !a.IsOpenAIApiKey() { if !a.IsOpenAIAPIKey() {
return "" return ""
} }
return a.GetCredential("api_key") return a.GetCredential("api_key")
......
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service package service
import ( import (
......
...@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID() chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" { } else if account.Type == "apikey" {
// API Key - use Platform API // API Key - use Platform API
authToken = account.GetOpenAIApiKey() authToken = account.GetOpenAIAPIKey()
if authToken == "" { if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available") return s.sendErrorAndEnd(c, "No API key available")
} }
...@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
...@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error var err error
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth: case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
......
...@@ -17,11 +17,11 @@ type UsageLogRepository interface { ...@@ -17,11 +17,11 @@ type UsageLogRepository interface {
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
...@@ -32,10 +32,10 @@ type UsageLogRepository interface { ...@@ -32,10 +32,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats // User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
...@@ -51,7 +51,7 @@ type UsageLogRepository interface { ...@@ -51,7 +51,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment