"backend/internal/vscode:/vscode.git/clone" did not exist on "d6f8ac0226a16e9e8f96bb0383a5626d88f9f081"
Unverified Commit 45bd9ac7 authored by IanShaw's avatar IanShaw Committed by GitHub
Browse files

运维监控系统安全加固和功能优化 (#21)

* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
parent 7fdc2b2d
...@@ -19,7 +19,7 @@ type AdminService interface { ...@@ -19,7 +19,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
...@@ -30,7 +30,7 @@ type AdminService interface { ...@@ -30,7 +30,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
...@@ -65,7 +65,7 @@ type AdminService interface { ...@@ -65,7 +65,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
} }
// Input types for admin operations // CreateUserInput represents the input for creating a new user
type CreateUserInput struct { type CreateUserInput struct {
Email string Email string
Password string Password string
...@@ -220,7 +220,7 @@ type adminServiceImpl struct { ...@@ -220,7 +220,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository groupRepo GroupRepository
accountRepo AccountRepository accountRepo AccountRepository
proxyRepo ProxyRepository proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
...@@ -232,7 +232,7 @@ func NewAdminService( ...@@ -232,7 +232,7 @@ func NewAdminService(
groupRepo GroupRepository, groupRepo GroupRepository,
accountRepo AccountRepository, accountRepo AccountRepository,
proxyRepo ProxyRepository, proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
...@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, ...@@ -430,7 +430,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
...@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { ...@@ -583,7 +583,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
......
...@@ -2,7 +2,7 @@ package service ...@@ -2,7 +2,7 @@ package service
import "time" import "time"
type ApiKey struct { type APIKey struct {
ID int64 ID int64
UserID int64 UserID int64
Key string Key string
...@@ -15,6 +15,6 @@ type ApiKey struct { ...@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group Group *Group
} }
func (k *ApiKey) IsActive() bool { func (k *APIKey) IsActive() bool {
return k.Status == StatusActive return k.Status == StatusActive
} }
...@@ -14,39 +14,39 @@ import ( ...@@ -14,39 +14,39 @@ import (
) )
var ( var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
apiKeyMaxErrorsPerHour = 20 apiKeyMaxErrorsPerHour = 20
) )
type ApiKeyRepository interface { type APIKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error) GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error) GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *ApiKey) error Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
// ApiKeyCache defines cache operations for API key service // APIKeyCache defines cache operations for API key service
type ApiKeyCache interface { type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error DeleteCreateAttemptCount(ctx context.Context, userID int64) error
...@@ -55,40 +55,40 @@ type ApiKeyCache interface { ...@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
} }
// CreateApiKeyRequest 创建API Key请求 // CreateAPIKeyRequest 创建API Key请求
type CreateApiKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
} }
// UpdateApiKeyRequest 更新API Key请求 // UpdateAPIKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct { type UpdateAPIKeyRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status *string `json:"status"` Status *string `json:"status"`
} }
// ApiKeyService API Key服务 // APIKeyService API Key服务
type ApiKeyService struct { type APIKeyService struct {
apiKeyRepo ApiKeyRepository apiKeyRepo APIKeyRepository
userRepo UserRepository userRepo UserRepository
groupRepo GroupRepository groupRepo GroupRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache ApiKeyCache cache APIKeyCache
cfg *config.Config cfg *config.Config
} }
// NewApiKeyService 创建API Key服务实例 // NewAPIKeyService 创建API Key服务实例
func NewApiKeyService( func NewAPIKeyService(
apiKeyRepo ApiKeyRepository, apiKeyRepo APIKeyRepository,
userRepo UserRepository, userRepo UserRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache ApiKeyCache, cache APIKeyCache,
cfg *config.Config, cfg *config.Config,
) *ApiKeyService { ) *APIKeyService {
return &ApiKeyService{ return &APIKeyService{
apiKeyRepo: apiKeyRepo, apiKeyRepo: apiKeyRepo,
userRepo: userRepo, userRepo: userRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
...@@ -99,7 +99,7 @@ func NewApiKeyService( ...@@ -99,7 +99,7 @@ func NewApiKeyService(
} }
// GenerateKey 生成随机API Key // GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) { func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据 // 生成32字节随机数据
bytes := make([]byte, 32) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// 转换为十六进制字符串并添加前缀 // 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" { if prefix == "" {
prefix = "sk-" prefix = "sk-"
} }
...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) { ...@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
} }
// ValidateCustomKey 验证自定义API Key格式 // ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error { func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度 // 检查长度
if len(key) < 16 { if len(key) < 16 {
return ErrApiKeyTooShort return ErrAPIKeyTooShort
} }
// 检查字符:只允许字母、数字、下划线、连字符 // 检查字符:只允许字母、数字、下划线、连字符
...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { ...@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' { c == '_' || c == '-' {
continue continue
} }
return ErrApiKeyInvalidChars return ErrAPIKeyInvalidChars
} }
return nil return nil
} }
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限 // checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error { func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
} }
...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) ...@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
} }
if count >= apiKeyMaxErrorsPerHour { if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited return ErrAPIKeyRateLimited
} }
return nil return nil
} }
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数 // incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) { func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil { if s.cache == nil {
return return
} }
...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in ...@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group ...@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
} }
// Create 创建API Key // Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key // 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" { if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流(仅对自定义key进行限流) // 检查限流(仅对自定义key进行限流)
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil { if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err return nil, err
} }
...@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -219,9 +219,9 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
return nil, fmt.Errorf("check key exists: %w", err) return nil, fmt.Errorf("check key exists: %w", err)
} }
if exists { if exists {
// Key已存在增加错误计数 // Key已存在,增加错误计数
s.incrementApiKeyErrorCount(ctx, userID) s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists return nil, ErrAPIKeyExists
} }
key = *req.CustomKey key = *req.CustomKey
...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &ApiKey{ apiKey := &APIKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK ...@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio ...@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil return keys, pagination, nil
} }
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return []int64{}, nil return []int64{}, nil
} }
...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe ...@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) ...@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
} }
// GetByKey 根据Key字符串获取API Key(用于认证) // GetByKey 根据Key字符串获取API Key(用于认证)
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro ...@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
} }
// Update 更新API Key // Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) { func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req ...@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key // Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, // 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能 // 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象 // 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil { if err != nil {
...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro ...@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效(用于认证中间件) // ValidateKey 验证API Key是否有效(用于认证中间件)
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) { func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, * ...@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
} }
// IncrementUsage 增加API Key使用次数(可选:用于统计) // IncrementUsage 增加API Key使用次数(可选:用于统计)
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器 // 使用Redis计数器
if s.cache != nil { if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02")) cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { ...@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ ...@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc ...@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)
} }
......
//go:build unit //go:build unit
// API Key 服务删除方法的单元测试 // API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为, // 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理 // 包括权限验证、缓存清理和错误处理
package service package service
...@@ -16,12 +16,12 @@ import ( ...@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。 // apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。 // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
// //
// 设计说明: // 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID // - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound) // - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
// - deleteErr: 模拟 Delete 返回的错误 // - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
type apiKeyRepoStub struct { type apiKeyRepoStub struct {
...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct { ...@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call") panic("unexpected Create call")
} }
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call") panic("unexpected GetByID call")
} }
...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error ...@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr return s.ownerID, s.ownerErr
} }
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) { func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call") panic("unexpected GetByKey call")
} }
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error { func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call") panic("unexpected Update call")
} }
...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { ...@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法 // 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call") panic("unexpected ListByUserID call")
} }
...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err ...@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call") panic("unexpected ExistsByKey call")
} }
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call") panic("unexpected ListByGroupID call")
} }
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchApiKeys call") panic("unexpected SearchAPIKeys call")
} }
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int ...@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call") panic("unexpected CountByGroupID call")
} }
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。 // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。
// //
// 设计说明: // 设计说明:
...@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string ...@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil return nil
} }
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回所有者 ID 为 1 // - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2(不匹配) // - 调用者 userID 为 2(不匹配)
// - 返回 ErrInsufficientPerms 错误 // - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1} repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2 err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms) require.ErrorIs(t, err, ErrInsufficientPerms)
...@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { ...@@ -150,17 +150,17 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除 require.Empty(t, cache.invalidated) // 验证缓存未被清除
} }
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为: // 预期行为:
// - GetOwnerID 返回所有者 ID 为 7 // - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7(匹配) // - 调用者 userID 为 7(匹配)
// - Delete 成功执行 // - Delete 成功执行
// - 缓存被正确清除(使用 ownerID) // - 缓存被正确清除(使用 ownerID)
// - 返回 nil 错误 // - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) { func TestAPIKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7} repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err) require.NoError(t, err)
...@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) { ...@@ -168,37 +168,37 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
} }
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为: // 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误 // - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装) // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用 // - Delete 方法不被调用
// - 缓存不被清除 // - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) { func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound} repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1) err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound) require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs) require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated) require.Empty(t, cache.invalidated)
} }
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为: // 预期行为:
// - GetOwnerID 返回正确的所有者 ID // - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过 // - 所有权验证通过
// - 缓存被清除(在删除之前) // - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误 // - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息 // - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) { func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{ repo := &apiKeyRepoStub{
ownerID: 3, ownerID: 3,
deleteErr: errors.New("delete failed"), deleteErr: errors.New("delete failed"),
} }
cache := &apiKeyCacheStub{} cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache} svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3 err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err) require.Error(t, err)
......
...@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID ...@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求 // CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查 // 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple { if s.cfg.RunMode == config.RunModeSimple {
return nil return nil
......
...@@ -32,6 +32,7 @@ type ConcurrencyCache interface { ...@@ -32,6 +32,7 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL) // 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
GetTotalWaitCount(ctx context.Context) (int, error)
// 批量负载查询(只读) // 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
...@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 ...@@ -200,6 +201,14 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
} }
} }
// GetTotalWaitCount returns the total wait queue depth across users.
func (s *ConcurrencyService) GetTotalWaitCount(ctx context.Context) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetTotalWaitCount(ctx)
}
// IncrementAccountWaitCount increments the wait queue counter for an account. // IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil { if s.cache == nil {
......
...@@ -82,7 +82,7 @@ type crsExportResponse struct { ...@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"` OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"` OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"` GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"` GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
} `json:"data"` } `json:"data"`
} }
...@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
...@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini, Platform: PlatformGemini,
Type: AccountTypeApiKey, Type: AccountTypeAPIKey,
Credentials: credentials, Credentials: credentials,
Extra: extra, Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
...@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
......
...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi ...@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil return stats, nil
} }
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit) trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err) return nil, fmt.Errorf("get api key usage trend: %w", err)
} }
...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [ ...@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil return stats, nil
} }
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs) stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err) return nil, fmt.Errorf("get batch api key usage stats: %w", err)
} }
......
...@@ -28,7 +28,7 @@ const ( ...@@ -28,7 +28,7 @@ const (
const ( const (
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeApiKey = "apikey" // API Key类型账号 AccountTypeAPIKey = "apikey" // API Key类型账号
) )
// Redeem type constants // Redeem type constants
...@@ -64,13 +64,13 @@ const ( ...@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置 // 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口 SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名 SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
SettingKeySmtpFrom = "smtp_from" // 发件人地址 SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置 // Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
...@@ -81,20 +81,20 @@ const ( ...@@ -81,20 +81,20 @@ const (
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接 SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key // 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON) // Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
) )
// Admin API Key prefix (distinct from user "sk-" keys) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-" const AdminAPIKeyPrefix = "admin-"
...@@ -40,8 +40,8 @@ const ( ...@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5 maxVerifyCodeAttempts = 5
) )
// SmtpConfig SMTP配置 // SMTPConfig SMTP配置
type SmtpConfig struct { type SMTPConfig struct {
Host string Host string
Port int Port int
Username string Username string
...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ ...@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
} }
} }
// GetSmtpConfig 从数据库获取SMTP配置 // GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{ keys := []string{
SettingKeySmtpHost, SettingKeySMTPHost,
SettingKeySmtpPort, SettingKeySMTPPort,
SettingKeySmtpUsername, SettingKeySMTPUsername,
SettingKeySmtpPassword, SettingKeySMTPPassword,
SettingKeySmtpFrom, SettingKeySMTPFrom,
SettingKeySmtpFromName, SettingKeySMTPFromName,
SettingKeySmtpUseTLS, SettingKeySMTPUseTLS,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { ...@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[SettingKeySmtpHost] host := settings[SettingKeySMTPHost]
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
port := 587 // 默认端口 port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" { if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil { if p, err := strconv.Atoi(portStr); err == nil {
port = p port = p
} }
} }
useTLS := settings[SettingKeySmtpUseTLS] == "true" useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{ return &SMTPConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[SettingKeySmtpUsername], Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySmtpPassword], Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySmtpFrom], From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySmtpFromName], FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }
// SendEmail 发送邮件(使用数据库中保存的配置) // SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error { func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx) config, err := s.GetSMTPConfig(ctx)
if err != nil { if err != nil {
return err return err
} }
...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) ...@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
} }
// SendEmailWithConfig 使用指定配置发送邮件 // SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error { func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From from := config.From
if config.FromName != "" { if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From) from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string { ...@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code) `, siteName, code)
} }
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接 // TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port) addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS { if config.UseTLS {
......
...@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference( ...@@ -276,7 +276,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
...@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ...@@ -617,7 +617,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
......
...@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( ...@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken: case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow // Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account) return s.getOAuthToken(ctx, account)
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
...@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射(仅对apikey类型账号) // 应用模型映射(仅对apikey类型账号)
originalModel := reqModel originalModel := reqModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
// 替换请求体中的模型名 // 替换请求体中的模型名
...@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1110,7 +1110,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" targetURL = baseURL + "/v1/messages"
} }
...@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1178,10 +1178,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理) // 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" { if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }
...@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool { ...@@ -1248,12 +1248,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false return false
} }
func defaultApiKeyBetaHeader(body []byte) string { func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String() modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader return claude.APIKeyHaikuBetaHeader
} }
return claude.ApiKeyBetaHeader return claude.APIKeyBetaHeader
} }
func truncateForLog(b []byte, maxBytes int) string { func truncateForLog(b []byte, maxBytes int) string {
...@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo ...@@ -1630,7 +1630,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ApiKey *ApiKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
...@@ -1639,7 +1639,7 @@ type RecordUsageInput struct { ...@@ -1639,7 +1639,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.ApiKey apiKey := input.APIKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
...@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -1676,7 +1676,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
...@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1762,7 +1762,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
if reqModel != "" { if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel { if mappedModel != reqModel {
...@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1848,7 +1848,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" targetURL = baseURL + "/v1/messages/count_tokens"
} }
...@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -1910,10 +1910,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭) // API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) { if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" { if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta) req.Header.Set("anthropic-beta", beta)
} }
} }
......
...@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont ...@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999 return 999
} }
switch a.Type { switch a.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" { if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0 return 0
} }
...@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model originalModel := req.Model
mappedModel := req.Model mappedModel := req.Model
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model) mappedModel = account.GetMappedModel(req.Model)
} }
...@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
...@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -614,7 +614,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
mappedModel := originalModel mappedModel := originalModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
} }
...@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -636,7 +636,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error) var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
...@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac ...@@ -1758,7 +1758,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
} }
switch account.Type { switch account.Type {
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key")) apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" { if apiKey == "" {
return nil, errors.New("gemini api_key not configured") return nil, errors.New("gemini api_key not configured")
......
...@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr ...@@ -275,7 +275,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{ repo := &mockAccountRepoForGemini{
accounts: []Account{ accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil}, {ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
}, },
accountsByID: map[int64]*Account{}, accountsByID: map[int64]*Account{},
......
...@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string { ...@@ -251,7 +251,7 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown return TierGoogleOneUnknown
} }
// fetchGoogleOneTier fetches Google One tier from Drive API // FetchGoogleOneTier fetches Google One tier from Drive API
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
driveClient := geminicli.NewDriveClient() driveClient := geminicli.NewDriveClient()
......
...@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco ...@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
case AccountTypeApiKey: case AccountTypeAPIKey:
apiKey := account.GetOpenAIApiKey() apiKey := account.GetOpenAIAPIKey()
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
} }
...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. ...@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth: case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API // OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL targetURL = chatgptCodexURL
case AccountTypeApiKey: case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL != "" {
...@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
ApiKey *ApiKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
...@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct { ...@@ -949,7 +949,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result result := input.Result
apiKey := input.ApiKey apiKey := input.APIKey
user := input.User user := input.User
account := input.Account account := input.Account
subscription := input.Subscription subscription := input.Subscription
...@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -991,7 +991,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
......
package service
import (
"context"
"time"
)
// ErrorLog represents an ops error log item for list queries.
//
// Field naming matches docs/API-运维监控中心2.0.md (L3 根因追踪 - 错误日志列表).
type ErrorLog struct {
ID int64 `json:"id"`
Timestamp time.Time `json:"timestamp"`
Level string `json:"level,omitempty"`
RequestID string `json:"request_id,omitempty"`
AccountID string `json:"account_id,omitempty"`
APIPath string `json:"api_path,omitempty"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
HTTPCode int `json:"http_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
DurationMs *int `json:"duration_ms,omitempty"`
RetryCount *int `json:"retry_count,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ErrorLogFilter describes optional filters and pagination for listing ops error logs.
type ErrorLogFilter struct {
StartTime *time.Time
EndTime *time.Time
ErrorCode *int
Provider string
AccountID *int64
Page int
PageSize int
}
func (f *ErrorLogFilter) normalize() (page, pageSize int) {
page = 1
pageSize = 20
if f == nil {
return page, pageSize
}
if f.Page > 0 {
page = f.Page
}
if f.PageSize > 0 {
pageSize = f.PageSize
}
if pageSize > 100 {
pageSize = 100
}
return page, pageSize
}
type ErrorLogListResponse struct {
Errors []*ErrorLog `json:"errors"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *ErrorLogFilter) (*ErrorLogListResponse, error) {
if s == nil || s.repo == nil {
return &ErrorLogListResponse{
Errors: []*ErrorLog{},
Total: 0,
Page: 1,
PageSize: 20,
}, nil
}
page, pageSize := filter.normalize()
if filter == nil {
filter = &ErrorLogFilter{}
}
filter.Page = page
filter.PageSize = pageSize
items, total, err := s.repo.ListErrorLogs(ctx, filter)
if err != nil {
return nil, err
}
if items == nil {
items = []*ErrorLog{}
}
return &ErrorLogListResponse{
Errors: items,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
type OpsAlertService struct {
opsService *OpsService
userService *UserService
emailService *EmailService
httpClient *http.Client
interval time.Duration
startOnce sync.Once
stopOnce sync.Once
stopCtx context.Context
stop context.CancelFunc
wg sync.WaitGroup
}
// opsAlertEvalInterval defines how often OpsAlertService evaluates alert rules.
//
// Production uses opsMetricsInterval. Tests may override this variable to keep
// integration tests fast without changing production defaults.
var opsAlertEvalInterval = opsMetricsInterval
func NewOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
return &OpsAlertService{
opsService: opsService,
userService: userService,
emailService: emailService,
httpClient: &http.Client{Timeout: 10 * time.Second},
interval: opsAlertEvalInterval,
}
}
// Start launches the background alert evaluation loop.
//
// Stop must be called during shutdown to ensure the goroutine exits.
func (s *OpsAlertService) Start() {
s.StartWithContext(context.Background())
}
// StartWithContext is like Start but allows the caller to provide a parent context.
// When the parent context is canceled, the service stops automatically.
func (s *OpsAlertService) StartWithContext(ctx context.Context) {
if s == nil {
return
}
if ctx == nil {
ctx = context.Background()
}
s.startOnce.Do(func() {
if s.interval <= 0 {
s.interval = opsAlertEvalInterval
}
s.stopCtx, s.stop = context.WithCancel(ctx)
s.wg.Add(1)
go s.run()
})
}
// Stop gracefully stops the background goroutine started by Start/StartWithContext.
// It is safe to call Stop multiple times.
func (s *OpsAlertService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.stop != nil {
s.stop()
}
})
s.wg.Wait()
}
func (s *OpsAlertService) run() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.evaluateOnce()
for {
select {
case <-ticker.C:
s.evaluateOnce()
case <-s.stopCtx.Done():
return
}
}
}
func (s *OpsAlertService) evaluateOnce() {
ctx, cancel := context.WithTimeout(s.stopCtx, opsAlertEvaluateTimeout)
defer cancel()
s.Evaluate(ctx, time.Now())
}
func (s *OpsAlertService) Evaluate(ctx context.Context, now time.Time) {
if s == nil || s.opsService == nil {
return
}
rules, err := s.opsService.ListAlertRules(ctx)
if err != nil {
log.Printf("[OpsAlert] failed to list rules: %v", err)
return
}
if len(rules) == 0 {
return
}
maxSustainedByWindow := make(map[int]int)
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
if sustained > maxSustainedByWindow[window] {
maxSustainedByWindow[window] = sustained
}
}
metricsByWindow := make(map[int][]OpsMetrics)
for window, limit := range maxSustainedByWindow {
metrics, err := s.opsService.ListRecentSystemMetrics(ctx, window, limit)
if err != nil {
log.Printf("[OpsAlert] failed to load metrics window=%dm: %v", window, err)
continue
}
metricsByWindow[window] = metrics
}
for _, rule := range rules {
if !rule.Enabled {
continue
}
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
sustained := rule.SustainedMinutes
if sustained <= 0 {
sustained = 1
}
metrics := metricsByWindow[window]
selected, ok := selectContiguousMetrics(metrics, sustained, now)
if !ok {
continue
}
breached, latestValue, ok := evaluateRule(rule, selected)
if !ok {
continue
}
activeEvent, err := s.opsService.GetActiveAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get active event (rule=%d): %v", rule.ID, err)
continue
}
if breached {
if activeEvent != nil {
continue
}
lastEvent, err := s.opsService.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlert] failed to get latest event (rule=%d): %v", rule.ID, err)
continue
}
if lastEvent != nil && rule.CooldownMinutes > 0 {
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
if now.Sub(lastEvent.FiredAt) < cooldown {
continue
}
}
event := &OpsAlertEvent{
RuleID: rule.ID,
Severity: rule.Severity,
Status: OpsAlertStatusFiring,
Title: fmt.Sprintf("%s: %s", rule.Severity, rule.Name),
Description: buildAlertDescription(rule, latestValue),
MetricValue: latestValue,
ThresholdValue: rule.Threshold,
FiredAt: now,
CreatedAt: now,
}
if err := s.opsService.CreateAlertEvent(ctx, event); err != nil {
log.Printf("[OpsAlert] failed to create event (rule=%d): %v", rule.ID, err)
continue
}
emailSent, webhookSent := s.dispatchNotifications(ctx, rule, event)
if emailSent || webhookSent {
if err := s.opsService.UpdateAlertEventNotifications(ctx, event.ID, emailSent, webhookSent); err != nil {
log.Printf("[OpsAlert] failed to update notification flags (event=%d): %v", event.ID, err)
}
}
} else if activeEvent != nil {
resolvedAt := now
if err := s.opsService.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
log.Printf("[OpsAlert] failed to resolve event (event=%d): %v", activeEvent.ID, err)
}
}
}
}
const opsMetricsContinuityTolerance = 20 * time.Second
// selectContiguousMetrics picks the newest N metrics and verifies they are continuous.
//
// This prevents a sustained rule from triggering when metrics sampling has gaps
// (e.g. collector downtime) and avoids evaluating "stale" data.
//
// Assumptions:
// - Metrics are ordered by UpdatedAt DESC (newest first).
// - Metrics are expected to be collected at opsMetricsInterval cadence.
func selectContiguousMetrics(metrics []OpsMetrics, needed int, now time.Time) ([]OpsMetrics, bool) {
if needed <= 0 {
return nil, false
}
if len(metrics) < needed {
return nil, false
}
newest := metrics[0].UpdatedAt
if newest.IsZero() {
return nil, false
}
if now.Sub(newest) > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
selected := metrics[:needed]
for i := 0; i < len(selected)-1; i++ {
a := selected[i].UpdatedAt
b := selected[i+1].UpdatedAt
if a.IsZero() || b.IsZero() {
return nil, false
}
gap := a.Sub(b)
if gap < opsMetricsInterval-opsMetricsContinuityTolerance || gap > opsMetricsInterval+opsMetricsContinuityTolerance {
return nil, false
}
}
return selected, true
}
func evaluateRule(rule OpsAlertRule, metrics []OpsMetrics) (bool, float64, bool) {
if len(metrics) == 0 {
return false, 0, false
}
latestValue, ok := metricValue(metrics[0], rule.MetricType)
if !ok {
return false, 0, false
}
for _, metric := range metrics {
value, ok := metricValue(metric, rule.MetricType)
if !ok || !compareMetric(value, rule.Operator, rule.Threshold) {
return false, latestValue, true
}
}
return true, latestValue, true
}
func metricValue(metric OpsMetrics, metricType string) (float64, bool) {
switch metricType {
case OpsMetricSuccessRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.SuccessRate, true
case OpsMetricErrorRate:
if metric.RequestCount == 0 {
return 0, false
}
return metric.ErrorRate, true
case OpsMetricP95LatencyMs:
return float64(metric.P95LatencyMs), true
case OpsMetricP99LatencyMs:
return float64(metric.P99LatencyMs), true
case OpsMetricHTTP2Errors:
return float64(metric.HTTP2Errors), true
case OpsMetricCPUUsagePercent:
return metric.CPUUsagePercent, true
case OpsMetricMemoryUsagePercent:
return metric.MemoryUsagePercent, true
case OpsMetricQueueDepth:
return float64(metric.ConcurrencyQueueDepth), true
default:
return 0, false
}
}
func compareMetric(value float64, operator string, threshold float64) bool {
switch operator {
case ">":
return value > threshold
case ">=":
return value >= threshold
case "<":
return value < threshold
case "<=":
return value <= threshold
case "==":
return value == threshold
default:
return false
}
}
func buildAlertDescription(rule OpsAlertRule, value float64) string {
window := rule.WindowMinutes
if window <= 0 {
window = 1
}
return fmt.Sprintf("Rule %s triggered: %s %s %.2f (current %.2f) over last %dm",
rule.Name,
rule.MetricType,
rule.Operator,
rule.Threshold,
value,
window,
)
}
func (s *OpsAlertService) dispatchNotifications(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) (bool, bool) {
emailSent := false
webhookSent := false
notifyCtx, cancel := s.notificationContext(ctx)
defer cancel()
if rule.NotifyEmail {
emailSent = s.sendEmailNotification(notifyCtx, rule, event)
}
if rule.NotifyWebhook && rule.WebhookURL != "" {
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
// Fallback channel: if email is enabled but ultimately fails, try webhook even if the
// webhook toggle is off (as long as a webhook URL is configured).
if rule.NotifyEmail && !emailSent && !rule.NotifyWebhook && rule.WebhookURL != "" {
log.Printf("[OpsAlert] email failed; attempting webhook fallback (rule=%d)", rule.ID)
webhookSent = s.sendWebhookNotification(notifyCtx, rule, event)
}
return emailSent, webhookSent
}
const (
opsAlertEvaluateTimeout = 45 * time.Second
opsAlertNotificationTimeout = 30 * time.Second
opsAlertEmailMaxRetries = 3
)
var opsAlertEmailBackoff = []time.Duration{
1 * time.Second,
2 * time.Second,
4 * time.Second,
}
func (s *OpsAlertService) notificationContext(ctx context.Context) (context.Context, context.CancelFunc) {
parent := ctx
if s != nil && s.stopCtx != nil {
parent = s.stopCtx
}
if parent == nil {
parent = context.Background()
}
return context.WithTimeout(parent, opsAlertNotificationTimeout)
}
var opsAlertSleep = sleepWithContext
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
if ctx == nil {
time.Sleep(d)
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func retryWithBackoff(
ctx context.Context,
maxRetries int,
backoff []time.Duration,
fn func() error,
onError func(attempt int, total int, nextDelay time.Duration, err error),
) error {
if ctx == nil {
ctx = context.Background()
}
if maxRetries < 0 {
maxRetries = 0
}
totalAttempts := maxRetries + 1
var lastErr error
for attempt := 1; attempt <= totalAttempts; attempt++ {
if attempt > 1 {
backoffIdx := attempt - 2
if backoffIdx < len(backoff) {
if err := opsAlertSleep(ctx, backoff[backoffIdx]); err != nil {
return err
}
}
}
if err := ctx.Err(); err != nil {
return err
}
if err := fn(); err != nil {
lastErr = err
nextDelay := time.Duration(0)
if attempt < totalAttempts {
nextIdx := attempt - 1
if nextIdx < len(backoff) {
nextDelay = backoff[nextIdx]
}
}
if onError != nil {
onError(attempt, totalAttempts, nextDelay, err)
}
continue
}
return nil
}
return lastErr
}
func (s *OpsAlertService) sendEmailNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
if s.emailService == nil || s.userService == nil {
return false
}
if ctx == nil {
ctx = context.Background()
}
admin, err := s.userService.GetFirstAdmin(ctx)
if err != nil || admin == nil || admin.Email == "" {
return false
}
subject := fmt.Sprintf("[Ops Alert][%s] %s", rule.Severity, rule.Name)
body := fmt.Sprintf(
"Alert triggered: %s\n\nMetric: %s\nThreshold: %.2f\nCurrent: %.2f\nWindow: %dm\nStatus: %s\nTime: %s",
rule.Name,
rule.MetricType,
rule.Threshold,
event.MetricValue,
rule.WindowMinutes,
event.Status,
event.FiredAt.Format(time.RFC3339),
)
config, err := s.emailService.GetSMTPConfig(ctx)
if err != nil {
log.Printf("[OpsAlert] email config load failed: %v", err)
return false
}
if err := retryWithBackoff(
ctx,
opsAlertEmailMaxRetries,
opsAlertEmailBackoff,
func() error {
return s.emailService.SendEmailWithConfig(config, admin.Email, subject, body)
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt < total {
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), retrying in %s: %v", attempt, total, nextDelay, err)
return
}
log.Printf("[OpsAlert] email send failed (attempt=%d/%d), giving up: %v", attempt, total, err)
},
); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf("[OpsAlert] email send canceled: %v", err)
}
return false
}
return true
}
func (s *OpsAlertService) sendWebhookNotification(ctx context.Context, rule OpsAlertRule, event *OpsAlertEvent) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
webhookTarget, err := validateWebhookURL(ctx, rule.WebhookURL)
if err != nil {
log.Printf("[OpsAlert] invalid webhook url (rule=%d): %v", rule.ID, err)
return false
}
payload := map[string]any{
"rule_id": rule.ID,
"rule_name": rule.Name,
"severity": rule.Severity,
"status": event.Status,
"metric_type": rule.MetricType,
"metric_value": event.MetricValue,
"threshold_value": rule.Threshold,
"window_minutes": rule.WindowMinutes,
"fired_at": event.FiredAt.Format(time.RFC3339),
}
body, err := json.Marshal(payload)
if err != nil {
return false
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookTarget.URL.String(), bytes.NewReader(body))
if err != nil {
return false
}
req.Header.Set("Content-Type", "application/json")
resp, err := buildWebhookHTTPClient(s.httpClient, webhookTarget).Do(req)
if err != nil {
log.Printf("[OpsAlert] webhook send failed: %v", err)
return false
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
log.Printf("[OpsAlert] webhook returned status %d", resp.StatusCode)
return false
}
return true
}
const webhookHTTPClientTimeout = 10 * time.Second
func buildWebhookHTTPClient(base *http.Client, webhookTarget *validatedWebhookTarget) *http.Client {
var client http.Client
if base != nil {
client = *base
}
if client.Timeout <= 0 {
client.Timeout = webhookHTTPClientTimeout
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
if webhookTarget != nil {
client.Transport = buildWebhookTransport(client.Transport, webhookTarget)
}
return &client
}
var disallowedWebhookIPNets = []net.IPNet{
// "this host on this network" / unspecified.
mustParseCIDR("0.0.0.0/8"),
mustParseCIDR("127.0.0.0/8"), // loopback (includes 127.0.0.1)
mustParseCIDR("10.0.0.0/8"), // RFC1918
mustParseCIDR("192.168.0.0/16"), // RFC1918
mustParseCIDR("172.16.0.0/12"), // RFC1918 (172.16.0.0 - 172.31.255.255)
mustParseCIDR("100.64.0.0/10"), // RFC6598 (carrier-grade NAT)
mustParseCIDR("169.254.0.0/16"), // IPv4 link-local (includes 169.254.169.254 metadata IP on many clouds)
mustParseCIDR("198.18.0.0/15"), // RFC2544 benchmark testing
mustParseCIDR("224.0.0.0/4"), // IPv4 multicast
mustParseCIDR("240.0.0.0/4"), // IPv4 reserved
mustParseCIDR("::/128"), // IPv6 unspecified
mustParseCIDR("::1/128"), // IPv6 loopback
mustParseCIDR("fc00::/7"), // IPv6 unique local
mustParseCIDR("fe80::/10"), // IPv6 link-local
mustParseCIDR("ff00::/8"), // IPv6 multicast
}
func mustParseCIDR(cidr string) net.IPNet {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
return *block
}
var lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
}
type validatedWebhookTarget struct {
URL *url.URL
host string
port string
pinnedIPs []net.IP
}
var webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
}
func buildWebhookTransport(base http.RoundTripper, webhookTarget *validatedWebhookTarget) http.RoundTripper {
if webhookTarget == nil || webhookTarget.URL == nil {
return base
}
var transport *http.Transport
switch typed := base.(type) {
case *http.Transport:
if typed != nil {
transport = typed.Clone()
}
}
if transport == nil {
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok && defaultTransport != nil {
transport = defaultTransport.Clone()
} else {
transport = (&http.Transport{}).Clone()
}
}
webhookHost := webhookTarget.host
webhookPort := webhookTarget.port
pinnedIPs := append([]net.IP(nil), webhookTarget.pinnedIPs...)
transport.Proxy = nil
transport.DialTLSContext = nil
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil || host == "" || port == "" {
return nil, fmt.Errorf("webhook dial target is invalid: %q", addr)
}
canonicalHost := strings.TrimSuffix(strings.ToLower(host), ".")
if canonicalHost != webhookHost || port != webhookPort {
return nil, fmt.Errorf("webhook dial target mismatch: %q", addr)
}
var lastErr error
for _, ip := range pinnedIPs {
if isDisallowedWebhookIP(ip) {
lastErr = fmt.Errorf("webhook target resolves to a disallowed ip")
continue
}
dialAddr := net.JoinHostPort(ip.String(), port)
conn, err := webhookBaseDialContext(ctx, network, dialAddr)
if err == nil {
return conn, nil
}
lastErr = err
}
if lastErr == nil {
lastErr = errors.New("webhook target has no resolved addresses")
}
return nil, lastErr
}
return transport
}
func validateWebhookURL(ctx context.Context, raw string) (*validatedWebhookTarget, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, errors.New("webhook url is empty")
}
// Avoid request smuggling / header injection vectors.
if strings.ContainsAny(raw, "\r\n") {
return nil, errors.New("webhook url contains invalid characters")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, errors.New("webhook url format is invalid")
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, errors.New("webhook url scheme must be https")
}
parsed.Scheme = "https"
if parsed.Host == "" || parsed.Hostname() == "" {
return nil, errors.New("webhook url must include host")
}
if parsed.User != nil {
return nil, errors.New("webhook url must not include userinfo")
}
if parsed.Port() != "" {
port, err := strconv.Atoi(parsed.Port())
if err != nil || port < 1 || port > 65535 {
return nil, errors.New("webhook url port is invalid")
}
}
host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
if host == "localhost" {
return nil, errors.New("webhook url host must not be localhost")
}
if ip := net.ParseIP(host); ip != nil {
if isDisallowedWebhookIP(ip) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: []net.IP{ip},
}, nil
}
if ctx == nil {
ctx = context.Background()
}
ips, err := lookupIPAddrs(ctx, host)
if err != nil || len(ips) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
pinned := make([]net.IP, 0, len(ips))
for _, addr := range ips {
if isDisallowedWebhookIP(addr.IP) {
return nil, errors.New("webhook url host resolves to a disallowed ip")
}
if addr.IP != nil {
pinned = append(pinned, addr.IP)
}
}
if len(pinned) == 0 {
return nil, errors.New("webhook url host cannot be resolved")
}
return &validatedWebhookTarget{
URL: parsed,
host: host,
port: portForScheme(parsed),
pinnedIPs: uniqueResolvedIPs(pinned),
}, nil
}
func isDisallowedWebhookIP(ip net.IP) bool {
if ip == nil {
return false
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
} else if ip16 := ip.To16(); ip16 != nil {
ip = ip16
} else {
return false
}
// Disallow non-public addresses even if they're not explicitly covered by the CIDR list.
// This provides defense-in-depth against SSRF targets such as link-local, multicast, and
// unspecified addresses, and ensures any "pinned" IP is still blocked at dial time.
if ip.IsUnspecified() ||
ip.IsLoopback() ||
ip.IsMulticast() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsPrivate() {
return true
}
for _, block := range disallowedWebhookIPNets {
if block.Contains(ip) {
return true
}
}
return false
}
func portForScheme(u *url.URL) string {
if u != nil && u.Port() != "" {
return u.Port()
}
return "443"
}
func uniqueResolvedIPs(ips []net.IP) []net.IP {
seen := make(map[string]struct{}, len(ips))
out := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if ip == nil {
continue
}
key := ip.String()
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, ip)
}
return out
}
//go:build integration
package service
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// This integration test protects the DI startup contract for OpsAlertService.
//
// Background:
// - OpsMetricsCollector previously called alertService.Start()/Evaluate() directly.
// - Those direct calls were removed, so OpsAlertService must now start via DI
// (ProvideOpsAlertService in wire.go) and run its own evaluation ticker.
//
// What we validate here:
// 1. When we construct via the Wire provider functions (ProvideOpsAlertService +
// ProvideOpsMetricsCollector), OpsAlertService starts automatically.
// 2. Its evaluation loop continues to tick even if OpsMetricsCollector is stopped,
// proving the alert evaluator is independent.
// 3. The evaluation path can trigger alert logic (CreateAlertEvent called).
func TestOpsAlertService_StartedViaWireProviders_RunsIndependentTicker(t *testing.T) {
oldInterval := opsAlertEvalInterval
opsAlertEvalInterval = 25 * time.Millisecond
t.Cleanup(func() { opsAlertEvalInterval = oldInterval })
repo := newFakeOpsRepository()
opsService := NewOpsService(repo, nil)
// Start via the Wire provider function (the production DI path).
alertService := ProvideOpsAlertService(opsService, nil, nil)
t.Cleanup(alertService.Stop)
// Construct via ProvideOpsMetricsCollector (wire.go). Stop immediately to ensure
// the alert ticker keeps running without the metrics collector.
collector := ProvideOpsMetricsCollector(opsService, NewConcurrencyService(nil))
collector.Stop()
// Wait for at least one evaluation (run() calls evaluateOnce immediately).
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= 1
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation loop keeps ticking after the metrics collector is stopped.
callsAfterCollectorStop := repo.listRulesCalls.Load()
require.Eventually(t, func() bool {
return repo.listRulesCalls.Load() >= callsAfterCollectorStop+2
}, 1*time.Second, 5*time.Millisecond)
// Confirm the evaluation logic actually fires an alert event at least once.
select {
case <-repo.eventCreatedCh:
// ok
case <-time.After(2 * time.Second):
t.Fatalf("expected OpsAlertService to create an alert event, but none was created (ListAlertRules calls=%d)", repo.listRulesCalls.Load())
}
}
func newFakeOpsRepository() *fakeOpsRepository {
return &fakeOpsRepository{
eventCreatedCh: make(chan struct{}),
}
}
// fakeOpsRepository is a lightweight in-memory stub of OpsRepository for integration tests.
// It avoids real DB/Redis usage and provides deterministic responses fast.
type fakeOpsRepository struct {
listRulesCalls atomic.Int64
mu sync.Mutex
activeEvent *OpsAlertEvent
latestEvent *OpsAlertEvent
nextEventID int64
eventCreatedCh chan struct{}
eventOnce sync.Once
}
func (r *fakeOpsRepository) CreateErrorLog(ctx context.Context, log *OpsErrorLog) error {
return nil
}
func (r *fakeOpsRepository) ListErrorLogsLegacy(ctx context.Context, filters OpsErrorLogFilters) ([]OpsErrorLog, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListErrorLogs(ctx context.Context, filter *ErrorLogFilter) ([]*ErrorLog, int64, error) {
return nil, 0, nil
}
func (r *fakeOpsRepository) GetLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return &OpsMetrics{WindowMinutes: 1}, sql.ErrNoRows
}
func (r *fakeOpsRepository) CreateSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
return &OpsWindowStats{}, nil
}
func (r *fakeOpsRepository) GetProviderStats(ctx context.Context, startTime, endTime time.Time) ([]*ProviderStats, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetLatencyHistogram(ctx context.Context, startTime, endTime time.Time) ([]*LatencyHistogramItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) GetErrorDistribution(ctx context.Context, startTime, endTime time.Time) ([]*ErrorDistributionItem, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListRecentSystemMetrics(ctx context.Context, windowMinutes, limit int) ([]OpsMetrics, error) {
if limit <= 0 {
limit = 1
}
now := time.Now()
metrics := make([]OpsMetrics, 0, limit)
for i := 0; i < limit; i++ {
metrics = append(metrics, OpsMetrics{
WindowMinutes: windowMinutes,
CPUUsagePercent: 99,
UpdatedAt: now.Add(-time.Duration(i) * opsMetricsInterval),
})
}
return metrics, nil
}
func (r *fakeOpsRepository) ListSystemMetricsRange(ctx context.Context, windowMinutes int, startTime, endTime time.Time, limit int) ([]OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) ListAlertRules(ctx context.Context) ([]OpsAlertRule, error) {
call := r.listRulesCalls.Add(1)
// Delay enabling rules slightly so the test can stop OpsMetricsCollector first,
// then observe the alert evaluator ticking independently.
if call < 5 {
return nil, nil
}
return []OpsAlertRule{
{
ID: 1,
Name: "cpu too high (test)",
Enabled: true,
MetricType: OpsMetricCPUUsagePercent,
Operator: ">",
Threshold: 0,
WindowMinutes: 1,
SustainedMinutes: 1,
Severity: "P1",
NotifyEmail: false,
NotifyWebhook: false,
CooldownMinutes: 0,
},
}, nil
}
func (r *fakeOpsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return nil, nil
}
if r.activeEvent.RuleID != ruleID {
return nil, nil
}
if r.activeEvent.Status != OpsAlertStatusFiring {
return nil, nil
}
clone := *r.activeEvent
return &clone, nil
}
func (r *fakeOpsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.latestEvent == nil || r.latestEvent.RuleID != ruleID {
return nil, nil
}
clone := *r.latestEvent
return &clone, nil
}
func (r *fakeOpsRepository) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) error {
if event == nil {
return nil
}
r.mu.Lock()
defer r.mu.Unlock()
r.nextEventID++
event.ID = r.nextEventID
clone := *event
r.latestEvent = &clone
if clone.Status == OpsAlertStatusFiring {
r.activeEvent = &clone
}
r.eventOnce.Do(func() { close(r.eventCreatedCh) })
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.Status = status
r.activeEvent.ResolvedAt = resolvedAt
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.Status = status
r.latestEvent.ResolvedAt = resolvedAt
}
return nil
}
func (r *fakeOpsRepository) UpdateAlertEventNotifications(ctx context.Context, eventID int64, emailSent, webhookSent bool) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent != nil && r.activeEvent.ID == eventID {
r.activeEvent.EmailSent = emailSent
r.activeEvent.WebhookSent = webhookSent
}
if r.latestEvent != nil && r.latestEvent.ID == eventID {
r.latestEvent.EmailSent = emailSent
r.latestEvent.WebhookSent = webhookSent
}
return nil
}
func (r *fakeOpsRepository) CountActiveAlerts(ctx context.Context) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.activeEvent == nil {
return 0, nil
}
return 1, nil
}
func (r *fakeOpsRepository) GetOverviewStats(ctx context.Context, startTime, endTime time.Time) (*OverviewStats, error) {
return &OverviewStats{}, nil
}
func (r *fakeOpsRepository) GetCachedLatestSystemMetric(ctx context.Context) (*OpsMetrics, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedLatestSystemMetric(ctx context.Context, metric *OpsMetrics) error {
return nil
}
func (r *fakeOpsRepository) GetCachedDashboardOverview(ctx context.Context, timeRange string) (*DashboardOverviewData, error) {
return nil, nil
}
func (r *fakeOpsRepository) SetCachedDashboardOverview(ctx context.Context, timeRange string, data *DashboardOverviewData, ttl time.Duration) error {
return nil
}
func (r *fakeOpsRepository) PingRedis(ctx context.Context) error {
return nil
}
//go:build unit || opsalert_unit
package service
import (
"context"
"errors"
"net"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSelectContiguousMetrics_Contiguous(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
{UpdatedAt: now.Add(-1 * time.Minute)},
{UpdatedAt: now.Add(-2 * time.Minute)},
}
selected, ok := selectContiguousMetrics(metrics, 3, now)
require.True(t, ok)
require.Len(t, selected, 3)
}
func TestSelectContiguousMetrics_GapFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now},
// Missing the -1m sample (gap ~=2m).
{UpdatedAt: now.Add(-2 * time.Minute)},
{UpdatedAt: now.Add(-3 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 3, now)
require.False(t, ok)
}
func TestSelectContiguousMetrics_StaleNewestFails(t *testing.T) {
now := time.Date(2026, 1, 1, 0, 10, 0, 0, time.UTC)
metrics := []OpsMetrics{
{UpdatedAt: now.Add(-10 * time.Minute)},
{UpdatedAt: now.Add(-11 * time.Minute)},
}
_, ok := selectContiguousMetrics(metrics, 2, now)
require.False(t, ok)
}
func TestMetricValue_SuccessRate_NoTrafficIsNoData(t *testing.T) {
metric := OpsMetrics{
RequestCount: 0,
SuccessRate: 0,
}
value, ok := metricValue(metric, OpsMetricSuccessRate)
require.False(t, ok)
require.Equal(t, 0.0, value)
}
func TestOpsAlertService_StopWithoutStart_NoPanic(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
require.NotPanics(t, func() { s.Stop() })
}
func TestOpsAlertService_StartStop_Graceful(t *testing.T) {
s := NewOpsAlertService(nil, nil, nil)
s.interval = 5 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.StartWithContext(ctx)
done := make(chan struct{})
go func() {
s.Stop()
close(done)
}()
select {
case <-done:
// ok
case <-time.After(1 * time.Second):
t.Fatal("Stop did not return; background goroutine likely stuck")
}
require.NotPanics(t, func() { s.Stop() })
}
func TestBuildWebhookHTTPClient_DefaultTimeout(t *testing.T) {
client := buildWebhookHTTPClient(nil, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
require.ErrorIs(t, client.CheckRedirect(nil, nil), http.ErrUseLastResponse)
base := &http.Client{}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, webhookHTTPClientTimeout, client.Timeout)
require.NotNil(t, client.CheckRedirect)
base = &http.Client{Timeout: 2 * time.Second}
client = buildWebhookHTTPClient(base, nil)
require.Equal(t, 2*time.Second, client.Timeout)
require.NotNil(t, client.CheckRedirect)
}
func TestValidateWebhookURL_RequiresHTTPS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "http://example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_InvalidFormatRejected(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://[::1")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsUserinfo(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://user:pass@example.com/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLocalhost(t *testing.T) {
_, err := validateWebhookURL(context.Background(), "https://localhost/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsPrivateIPLiteral(t *testing.T) {
cases := []string{
"https://0.0.0.0/webhook",
"https://127.0.0.1/webhook",
"https://10.0.0.1/webhook",
"https://192.168.1.2/webhook",
"https://172.16.0.1/webhook",
"https://172.31.255.255/webhook",
"https://100.64.0.1/webhook",
"https://169.254.169.254/webhook",
"https://198.18.0.1/webhook",
"https://224.0.0.1/webhook",
"https://240.0.0.1/webhook",
"https://[::]/webhook",
"https://[::1]/webhook",
"https://[ff02::1]/webhook",
}
for _, tc := range cases {
t.Run(tc, func(t *testing.T) {
_, err := validateWebhookURL(context.Background(), tc)
require.Error(t, err)
})
}
}
func TestValidateWebhookURL_RejectsPrivateIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "internal.example", host)
return []net.IPAddr{{IP: net.ParseIP("10.0.0.2")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://internal.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_RejectsLinkLocalIPViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "metadata.example", host)
return []net.IPAddr{{IP: net.ParseIP("169.254.169.254")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://metadata.example/webhook")
require.Error(t, err)
}
func TestValidateWebhookURL_AllowsPublicHostViaDNS(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com:443/webhook")
require.NoError(t, err)
require.Equal(t, "https", target.URL.Scheme)
require.Equal(t, "example.com", target.URL.Hostname())
require.Equal(t, "443", target.URL.Port())
}
func TestValidateWebhookURL_RejectsInvalidPort(t *testing.T) {
oldLookup := lookupIPAddrs
t.Cleanup(func() { lookupIPAddrs = oldLookup })
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
_, err := validateWebhookURL(context.Background(), "https://example.com:99999/webhook")
require.Error(t, err)
}
func TestWebhookTransport_UsesPinnedIP_NoDNSRebinding(t *testing.T) {
oldLookup := lookupIPAddrs
oldDial := webhookBaseDialContext
t.Cleanup(func() {
lookupIPAddrs = oldLookup
webhookBaseDialContext = oldDial
})
lookupCalls := 0
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
require.Equal(t, "example.com", host)
return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
}
target, err := validateWebhookURL(context.Background(), "https://example.com/webhook")
require.NoError(t, err)
require.Equal(t, 1, lookupCalls)
lookupIPAddrs = func(ctx context.Context, host string) ([]net.IPAddr, error) {
lookupCalls++
return []net.IPAddr{{IP: net.ParseIP("10.0.0.1")}}, nil
}
var dialAddrs []string
webhookBaseDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialAddrs = append(dialAddrs, addr)
return nil, errors.New("dial blocked in test")
}
client := buildWebhookHTTPClient(nil, target)
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
_, err = transport.DialContext(context.Background(), "tcp", "example.com:443")
require.Error(t, err)
require.Equal(t, []string{"93.184.216.34:443"}, dialAddrs)
require.Equal(t, 1, lookupCalls, "dial path must not re-resolve DNS")
}
func TestRetryWithBackoff_SucceedsAfterRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return nil
}
attempts := 0
err := retryWithBackoff(
context.Background(),
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
if attempts <= 3 {
return errors.New("send failed")
}
return nil
},
nil,
)
require.NoError(t, err)
require.Equal(t, 4, attempts)
require.Equal(t, []time.Duration{time.Second, 2 * time.Second, 4 * time.Second}, slept)
}
func TestRetryWithBackoff_ContextCanceledStopsRetries(t *testing.T) {
oldSleep := opsAlertSleep
t.Cleanup(func() { opsAlertSleep = oldSleep })
var slept []time.Duration
opsAlertSleep = func(ctx context.Context, d time.Duration) error {
slept = append(slept, d)
return ctx.Err()
}
ctx, cancel := context.WithCancel(context.Background())
attempts := 0
err := retryWithBackoff(
ctx,
3,
[]time.Duration{time.Second, 2 * time.Second, 4 * time.Second},
func() error {
attempts++
return errors.New("send failed")
},
func(attempt int, total int, nextDelay time.Duration, err error) {
if attempt == 1 {
cancel()
}
},
)
require.ErrorIs(t, err, context.Canceled)
require.Equal(t, 1, attempts)
require.Equal(t, []time.Duration{time.Second}, slept)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment