Unverified Commit ac114738 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1850 from touwaeriol/feat/channel-insights

feat(monitor): channel monitor with available channels & feature flags
parents 0a80ec80 09fd83ab
......@@ -70,5 +70,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
})
}
......@@ -34,6 +34,8 @@ func ProvideAdminHandlers(
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
channelMonitorHandler *admin.ChannelMonitorHandler,
channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
) *AdminHandlers {
return &AdminHandlers{
......@@ -62,6 +64,8 @@ func ProvideAdminHandlers(
APIKey: apiKeyHandler,
ScheduledTest: scheduledTestHandler,
Channel: channelHandler,
ChannelMonitor: channelMonitorHandler,
ChannelMonitorTemplate: channelMonitorTemplateHandler,
Payment: paymentHandler,
}
}
......@@ -85,6 +89,7 @@ func ProvideHandlers(
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
channelMonitorUserHandler *ChannelMonitorUserHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
......@@ -92,6 +97,7 @@ func ProvideHandlers(
totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
paymentWebhookHandler *PaymentWebhookHandler,
availableChannelHandler *AvailableChannelHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers {
......@@ -103,6 +109,7 @@ func ProvideHandlers(
Redeem: redeemHandler,
Subscription: subscriptionHandler,
Announcement: announcementHandler,
ChannelMonitor: channelMonitorUserHandler,
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
......@@ -110,6 +117,7 @@ func ProvideHandlers(
Totp: totpHandler,
Payment: paymentHandler,
PaymentWebhook: paymentWebhookHandler,
AvailableChannel: availableChannelHandler,
}
}
......@@ -123,12 +131,14 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
NewChannelMonitorUserHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
NewPaymentHandler,
NewPaymentWebhookHandler,
NewAvailableChannelHandler,
// Admin handlers
admin.NewDashboardHandler,
......@@ -156,6 +166,8 @@ var ProviderSet = wire.NewSet(
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
admin.NewChannelHandler,
admin.NewChannelMonitorHandler,
admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
// AdminHandlers and Handlers constructors
......
package repository
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	dbent "github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
	"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/lib/pq"
)
// channelMonitorRepository implements service.ChannelMonitorRepository.
//
// Design notes:
//   - CRUD goes through ent, reusing the project's transaction-context support.
//   - Aggregate queries (latest-per-model / availability) use raw SQL, avoiding
//     ent boilerplate around GROUP BY and ensuring the indexes are actually hit.
type channelMonitorRepository struct {
	client *dbent.Client // ent client; a transactional client from ctx may take precedence
	db     *sql.DB       // raw connection for the hand-written aggregate SQL
}
// NewChannelMonitorRepository builds the repository backed by the given ent
// client (CRUD paths) and raw *sql.DB (aggregate SQL paths).
func NewChannelMonitorRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRepository {
	repo := &channelMonitorRepository{
		client: client,
		db:     db,
	}
	return repo
}
// ---------- CRUD ----------

// Create persists a new monitor and copies the DB-generated ID and timestamps
// back onto m. Honors a transactional client carried in ctx.
func (r *channelMonitorRepository) Create(ctx context.Context, m *service.ChannelMonitor) error {
	client := clientFromContext(ctx, r.client)
	builder := client.ChannelMonitor.Create().
		SetName(m.Name).
		SetProvider(channelmonitor.Provider(m.Provider)).
		SetEndpoint(m.Endpoint).
		SetAPIKeyEncrypted(m.APIKey). // the caller passes in ciphertext already
		SetPrimaryModel(m.PrimaryModel).
		SetExtraModels(emptySliceIfNil(m.ExtraModels)).
		SetGroupName(m.GroupName).
		SetEnabled(m.Enabled).
		SetIntervalSeconds(m.IntervalSeconds).
		SetCreatedBy(m.CreatedBy).
		SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
		SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
	// Optional fields: only set when provided, otherwise the columns stay unset.
	if m.TemplateID != nil {
		builder = builder.SetTemplateID(*m.TemplateID)
	}
	if m.BodyOverride != nil {
		builder = builder.SetBodyOverride(m.BodyOverride)
	}
	created, err := builder.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
	}
	// Propagate DB-generated values back to the caller's struct.
	m.ID = created.ID
	m.CreatedAt = created.CreatedAt
	m.UpdatedAt = created.UpdatedAt
	return nil
}
// GetByID loads a single monitor by primary key.
//
// It goes through clientFromContext so that — like the write paths (Create /
// Update / Delete / MarkChecked) and the template repo's ApplyToMonitors —
// a read issued inside a transaction observes that transaction's uncommitted
// writes. Outside a transaction the behavior is unchanged.
func (r *channelMonitorRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitor, error) {
	client := clientFromContext(ctx, r.client)
	row, err := client.ChannelMonitor.Query().
		Where(channelmonitor.IDEQ(id)).
		Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
	}
	return entToServiceMonitor(row), nil
}
// Update overwrites every mutable column of the monitor identified by m.ID
// and copies the fresh UpdatedAt back onto m. Nil TemplateID / BodyOverride
// clear the corresponding columns. Honors a transactional client from ctx.
func (r *channelMonitorRepository) Update(ctx context.Context, m *service.ChannelMonitor) error {
	client := clientFromContext(ctx, r.client)
	updater := client.ChannelMonitor.UpdateOneID(m.ID).
		SetName(m.Name).
		SetProvider(channelmonitor.Provider(m.Provider)).
		SetEndpoint(m.Endpoint).
		SetAPIKeyEncrypted(m.APIKey). // ciphertext, same contract as Create
		SetPrimaryModel(m.PrimaryModel).
		SetExtraModels(emptySliceIfNil(m.ExtraModels)).
		SetGroupName(m.GroupName).
		SetEnabled(m.Enabled).
		SetIntervalSeconds(m.IntervalSeconds).
		SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
		SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
	// A nil pointer means "unset": clear the column instead of writing a value.
	if m.TemplateID != nil {
		updater = updater.SetTemplateID(*m.TemplateID)
	} else {
		updater = updater.ClearTemplateID()
	}
	if m.BodyOverride != nil {
		updater = updater.SetBodyOverride(m.BodyOverride)
	} else {
		updater = updater.ClearBodyOverride()
	}
	updated, err := updater.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
	}
	m.UpdatedAt = updated.UpdatedAt
	return nil
}
// Delete removes the monitor with the given ID, honoring any transactional
// client carried in ctx. A missing row maps to ErrChannelMonitorNotFound.
func (r *channelMonitorRepository) Delete(ctx context.Context, id int64) error {
	client := clientFromContext(ctx, r.client)
	err := client.ChannelMonitor.DeleteOneID(id).Exec(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
	}
	return nil
}
// List returns one page of monitors plus the total count of rows matching the
// filters. Provider / Enabled filter exactly; Search matches name, group name
// or primary model case-insensitively. Page defaults to 1 and PageSize to 20
// when non-positive. Results are ordered by ID descending (newest first).
func (r *channelMonitorRepository) List(ctx context.Context, params service.ChannelMonitorListParams) ([]*service.ChannelMonitor, int64, error) {
	q := r.client.ChannelMonitor.Query()
	if params.Provider != "" {
		q = q.Where(channelmonitor.ProviderEQ(channelmonitor.Provider(params.Provider)))
	}
	if params.Enabled != nil {
		q = q.Where(channelmonitor.EnabledEQ(*params.Enabled))
	}
	if s := strings.TrimSpace(params.Search); s != "" {
		q = q.Where(channelmonitor.Or(
			channelmonitor.NameContainsFold(s),
			channelmonitor.GroupNameContainsFold(s),
			channelmonitor.PrimaryModelContainsFold(s),
		))
	}
	// Count before pagination so total reflects the full filtered set.
	total, err := q.Count(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("count monitors: %w", err)
	}
	pageSize := params.PageSize
	if pageSize <= 0 {
		pageSize = 20
	}
	page := params.Page
	if page <= 0 {
		page = 1
	}
	rows, err := q.
		Order(dbent.Desc(channelmonitor.FieldID)).
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		All(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("list monitors: %w", err)
	}
	out := make([]*service.ChannelMonitor, 0, len(rows))
	for _, row := range rows {
		out = append(out, entToServiceMonitor(row))
	}
	return out, int64(total), nil
}
// ---------- scheduler helpers ----------

// ListEnabled returns every monitor whose enabled flag is set, for the
// scheduler to build its work set from.
func (r *channelMonitorRepository) ListEnabled(ctx context.Context) ([]*service.ChannelMonitor, error) {
	monitors, err := r.client.ChannelMonitor.Query().
		Where(channelmonitor.EnabledEQ(true)).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("list enabled monitors: %w", err)
	}
	result := make([]*service.ChannelMonitor, len(monitors))
	for i, m := range monitors {
		result[i] = entToServiceMonitor(m)
	}
	return result, nil
}
// MarkChecked stamps last_checked_at on the monitor after a probe run.
// Honors a transactional client from ctx; a missing row maps to
// ErrChannelMonitorNotFound.
func (r *channelMonitorRepository) MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error {
	client := clientFromContext(ctx, r.client)
	err := client.ChannelMonitor.UpdateOneID(id).
		SetLastCheckedAt(checkedAt).
		Exec(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
	}
	return nil
}
// InsertHistoryBatch bulk-inserts probe results. An empty slice is a no-op.
// Nil latency pointers leave the corresponding columns NULL. Honors a
// transactional client from ctx.
func (r *channelMonitorRepository) InsertHistoryBatch(ctx context.Context, rows []*service.ChannelMonitorHistoryRow) error {
	if len(rows) == 0 {
		return nil
	}
	client := clientFromContext(ctx, r.client)
	bulk := make([]*dbent.ChannelMonitorHistoryCreate, 0, len(rows))
	for _, row := range rows {
		c := client.ChannelMonitorHistory.Create().
			SetMonitorID(row.MonitorID).
			SetModel(row.Model).
			SetStatus(channelmonitorhistory.Status(row.Status)).
			SetMessage(row.Message).
			SetCheckedAt(row.CheckedAt)
		if row.LatencyMs != nil {
			c = c.SetLatencyMs(*row.LatencyMs)
		}
		if row.PingLatencyMs != nil {
			c = c.SetPingLatencyMs(*row.PingLatencyMs)
		}
		bulk = append(bulk, c)
	}
	if _, err := client.ChannelMonitorHistory.CreateBulk(bulk...).Save(ctx); err != nil {
		return fmt.Errorf("insert history bulk: %w", err)
	}
	return nil
}
// DeleteHistoryBefore physically deletes detail rows with checked_at < before,
// in batches of channelMonitorPruneBatchSize, so one huge transaction cannot
// build up lock/WAL pressure. The (checked_at) index locates a small batch of
// ids which are then deleted by id. Returns the cumulative deleted row count.
func (r *channelMonitorRepository) DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error) {
	return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneHistorySQL, before)
}
// ListHistory returns the most recent limit history rows of one monitor,
// newest first (checked_at DESC). An empty model means no model filter;
// otherwise only rows for that model are returned.
//
// A non-positive limit falls back to 20 (the same default List uses for
// pageSize) so the query never emits a zero or negative LIMIT.
func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*service.ChannelMonitorHistoryEntry, error) {
	if limit <= 0 {
		limit = 20 // defensive default against bad caller input
	}
	q := r.client.ChannelMonitorHistory.Query().
		Where(channelmonitorhistory.MonitorIDEQ(monitorID))
	if strings.TrimSpace(model) != "" {
		q = q.Where(channelmonitorhistory.ModelEQ(model))
	}
	rows, err := q.
		Order(dbent.Desc(channelmonitorhistory.FieldCheckedAt)).
		Limit(limit).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("list history: %w", err)
	}
	out := make([]*service.ChannelMonitorHistoryEntry, 0, len(rows))
	for _, row := range rows {
		out = append(out, &service.ChannelMonitorHistoryEntry{
			ID:            row.ID,
			Model:         row.Model,
			Status:        string(row.Status),
			LatencyMs:     row.LatencyMs,
			PingLatencyMs: row.PingLatencyMs,
			Message:       row.Message,
			CheckedAt:     row.CheckedAt,
		})
	}
	return out, nil
}
// ---------- user-facing aggregates (raw SQL) ----------

// ListLatestPerModel uses DISTINCT ON to fetch the most recent history row
// for each model of one monitor. With the (monitor_id, model, checked_at DESC)
// index this can run as an index scan.
func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) {
	const q = `
		SELECT DISTINCT ON (model)
		       model, status, latency_ms, ping_latency_ms, checked_at
		FROM channel_monitor_histories
		WHERE monitor_id = $1
		ORDER BY model, checked_at DESC
	`
	rows, err := r.db.QueryContext(ctx, q, monitorID)
	if err != nil {
		return nil, fmt.Errorf("query latest per model: %w", err)
	}
	defer func() { _ = rows.Close() }()
	out := make([]*service.ChannelMonitorLatest, 0)
	for rows.Next() {
		l := &service.ChannelMonitorLatest{}
		var latency, ping sql.NullInt64
		if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
			return nil, fmt.Errorf("scan latest row: %w", err)
		}
		assignNullInt(&l.LatencyMs, latency)
		assignNullInt(&l.PingLatencyMs, ping)
		out = append(out, l)
	}
	return out, rows.Err()
}
// assignNullInt 把 sql.NullInt64 解包到 *int 指针目标(valid 才分配新 int)。
// 集中实现避免 latency / ping 两处重复 if latency.Valid { v := int(...) ... } 模板。
func assignNullInt(dst **int, n sql.NullInt64) {
if !n.Valid {
return
}
v := int(n.Int64)
*dst = &v
}
// ComputeAvailability computes, for each model of the given monitor, the
// availability percentage and average latency within the window (default
// 7 days when windowDays is non-positive).
// "Available" = status IN (operational, degraded).
//
// Data source: scans the raw channel_monitor_histories table directly with
// second-level precision, avoiding the UTC day-boundary precision loss a
// UNION with the daily rollup table would introduce.
// NOTE(review): the original comments disagreed on the detail-table
// retention (1 day vs 30 days / monitorHistoryRetentionDays); this query is
// only complete while windowDays stays within the actual retention — confirm
// the retention configuration.
func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
	if windowDays <= 0 {
		windowDays = 7
	}
	const q = `
		SELECT model,
		       COUNT(*) AS total,
		       COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
		       CASE WHEN COUNT(latency_ms) > 0
		            THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
		            ELSE NULL END AS avg_latency_ms
		FROM channel_monitor_histories
		WHERE monitor_id = $1
		  AND checked_at >= NOW() - ($2::int || ' days')::interval
		GROUP BY model
	`
	rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays)
	if err != nil {
		return nil, fmt.Errorf("query availability: %w", err)
	}
	defer func() { _ = rows.Close() }()
	out := make([]*service.ChannelMonitorAvailability, 0)
	for rows.Next() {
		row, err := scanAvailabilityRow(rows, windowDays)
		if err != nil {
			return nil, err
		}
		out = append(out, row)
	}
	return out, rows.Err()
}
// scanAvailabilityRow scans one (model, total, ok, avg_latency) row into a
// ChannelMonitorAvailability. It serves only ComputeAvailability (4 columns);
// the batch variant carries an extra monitor_id column, so it scans inline
// and calls finalizeAvailabilityRow directly.
func scanAvailabilityRow(rows interface{ Scan(...any) error }, windowDays int) (*service.ChannelMonitorAvailability, error) {
	out := &service.ChannelMonitorAvailability{WindowDays: windowDays}
	var avg sql.NullFloat64
	err := rows.Scan(&out.Model, &out.TotalChecks, &out.OperationalChecks, &avg)
	if err != nil {
		return nil, fmt.Errorf("scan availability row: %w", err)
	}
	finalizeAvailabilityRow(out, avg)
	return out, nil
}
// finalizeAvailabilityRow derives the availability percentage from
// OperationalChecks/TotalChecks and unwraps the NullFloat64 average latency
// into *int. Shared between the single-monitor and batch paths so the math
// cannot drift between them.
func finalizeAvailabilityRow(row *service.ChannelMonitorAvailability, avgLatency sql.NullFloat64) {
	if row.TotalChecks > 0 {
		pct := float64(row.OperationalChecks) * 100.0 / float64(row.TotalChecks)
		row.AvailabilityPct = pct
	}
	if avgLatency.Valid {
		ms := int(avgLatency.Float64)
		row.AvgLatencyMs = &ms
	}
}
// ListLatestForMonitorIDs fetches, in one query, the most recent history row
// per (monitor_id, model) for several monitors, keyed by monitor ID.
// Uses PostgreSQL's DISTINCT ON; with the (monitor_id, model, checked_at DESC)
// index this can run as an index scan. An empty ids slice returns an empty map.
func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*service.ChannelMonitorLatest, error) {
	out := make(map[int64][]*service.ChannelMonitorLatest, len(ids))
	if len(ids) == 0 {
		return out, nil
	}
	const q = `
		SELECT DISTINCT ON (monitor_id, model)
		       monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
		FROM channel_monitor_histories
		WHERE monitor_id = ANY($1)
		ORDER BY monitor_id, model, checked_at DESC
	`
	rows, err := r.db.QueryContext(ctx, q, pq.Array(ids))
	if err != nil {
		return nil, fmt.Errorf("query latest batch: %w", err)
	}
	defer func() { _ = rows.Close() }()
	for rows.Next() {
		var monitorID int64
		l := &service.ChannelMonitorLatest{}
		var latency, ping sql.NullInt64
		if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
			return nil, fmt.Errorf("scan latest batch row: %w", err)
		}
		assignNullInt(&l.LatencyMs, latency)
		assignNullInt(&l.PingLatencyMs, ping)
		out[monitorID] = append(out[monitorID], l)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// ListRecentHistoryForMonitors fetches, for several monitors at once, each
// monitor's most recent N history rows for its designated model (checked_at
// DESC, newest first). primaryModels[monitorID] names the model to filter on;
// monitors absent from primaryModels (or with a blank model) are excluded.
// A CTE unnest-ing two parallel arrays (bigint[] / text[]) builds the
// (monitor_id, model) whitelist, and ROW_NUMBER() OVER (PARTITION BY
// monitor_id) keeps the first N rows per monitor.
//
// Returns map[monitorID] -> entries (message omitted to keep payloads small).
// Empty ids / primaryModels yield an empty map without error.
func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
	ctx context.Context,
	ids []int64,
	primaryModels map[int64]string,
	perMonitorLimit int,
) (map[int64][]*service.ChannelMonitorHistoryEntry, error) {
	out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids))
	pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels)
	if len(pairIDs) == 0 {
		return out, nil
	}
	// Clamp to [1, 200] so a bad limit cannot blow up the window function.
	perMonitorLimit = clampTimelineLimit(perMonitorLimit)
	const q = `
		WITH targets AS (
		    SELECT unnest($1::bigint[]) AS monitor_id,
		           unnest($2::text[])   AS model
		),
		ranked AS (
		    SELECT h.monitor_id,
		           h.status,
		           h.latency_ms,
		           h.ping_latency_ms,
		           h.checked_at,
		           ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn
		    FROM channel_monitor_histories h
		    JOIN targets t
		      ON t.monitor_id = h.monitor_id AND t.model = h.model
		)
		SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
		FROM ranked
		WHERE rn <= $3
		ORDER BY monitor_id, checked_at DESC
	`
	rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit)
	if err != nil {
		return nil, fmt.Errorf("query recent history batch: %w", err)
	}
	defer func() { _ = rows.Close() }()
	for rows.Next() {
		var monitorID int64
		entry := &service.ChannelMonitorHistoryEntry{}
		var latency, ping sql.NullInt64
		if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil {
			return nil, fmt.Errorf("scan recent history row: %w", err)
		}
		assignNullInt(&entry.LatencyMs, latency)
		assignNullInt(&entry.PingLatencyMs, ping)
		out[monitorID] = append(out[monitorID], entry)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// buildMonitorModelPairs filters ids down to the (monitor_id, model) pairs
// that have a non-blank model in primaryModels, skipping the rest. The two
// returned slices always have identical length and aligned indexes, ready to
// be expanded with unnest.
func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) {
	if len(ids) == 0 || len(primaryModels) == 0 {
		return nil, nil
	}
	outIDs := make([]int64, 0, len(ids))
	outModels := make([]string, 0, len(ids))
	for _, monitorID := range ids {
		// A missing key yields "", which the blank check below also skips.
		model := primaryModels[monitorID]
		if strings.TrimSpace(model) == "" {
			continue
		}
		outIDs = append(outIDs, monitorID)
		outModels = append(outModels, model)
	}
	return outIDs, outModels
}
// timelineLimit* bound the per-monitor limit of batch timeline queries.
// The floor of 1 guarantees at least the most recent row; the ceiling of 200
// caps response size and the ROW_NUMBER window's memory use.
const (
	timelineLimitMin = 1
	timelineLimitMax = 200
)

// clampTimelineLimit clamps n into [timelineLimitMin, timelineLimitMax],
// neutralizing invalid or oversized requests.
func clampTimelineLimit(n int) int {
	switch {
	case n < timelineLimitMin:
		return timelineLimitMin
	case n > timelineLimitMax:
		return timelineLimitMax
	default:
		return n
	}
}
// ComputeAvailabilityForMonitors computes, in one query, per-model
// availability percentage and average latency for several monitors within
// the window (default 7 days when windowDays is non-positive). It scans the
// raw histories table directly — valid while the window stays within the
// detail retention period. An empty ids slice returns an empty map.
func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
	out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
	if len(ids) == 0 {
		return out, nil
	}
	if windowDays <= 0 {
		windowDays = 7
	}
	const q = `
		SELECT monitor_id,
		       model,
		       COUNT(*) AS total,
		       COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
		       CASE WHEN COUNT(latency_ms) > 0
		            THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
		            ELSE NULL END AS avg_latency_ms
		FROM channel_monitor_histories
		WHERE monitor_id = ANY($1)
		  AND checked_at >= NOW() - ($2::int || ' days')::interval
		GROUP BY monitor_id, model
	`
	rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays)
	if err != nil {
		return nil, fmt.Errorf("query availability batch: %w", err)
	}
	defer func() { _ = rows.Close() }()
	for rows.Next() {
		var monitorID int64
		row := &service.ChannelMonitorAvailability{WindowDays: windowDays}
		var avgLatency sql.NullFloat64
		if err := rows.Scan(&monitorID, &row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil {
			return nil, fmt.Errorf("scan availability batch row: %w", err)
		}
		// The batch query has a leading monitor_id column; the percentage and
		// NullFloat unwrap match the single-monitor path, so the shared
		// finalizeAvailabilityRow keeps both from drifting apart.
		finalizeAvailabilityRow(row, avgLatency)
		out[monitorID] = append(out[monitorID], row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// ---------- rollup maintenance ----------

// UpsertDailyRollupsFor aggregates the detail rows of targetDate's UTC day
// ([targetDate, targetDate+1d)) by (monitor_id, model, bucket_date) into
// channel_monitor_daily_rollups. Returns the number of affected rollup rows.
//   - ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE makes backfills
//     idempotent: re-running simply overwrites with the latest statistics.
//   - $1::date lets PostgreSQL truncate the argument to a UTC date, so the
//     caller does not need to pre-normalize targetDate.
func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) {
	const q = `
		INSERT INTO channel_monitor_daily_rollups (
		    monitor_id, model, bucket_date,
		    total_checks, ok_count,
		    operational_count, degraded_count, failed_count, error_count,
		    sum_latency_ms, count_latency,
		    sum_ping_latency_ms, count_ping_latency,
		    computed_at
		)
		SELECT
		    monitor_id,
		    model,
		    $1::date AS bucket_date,
		    COUNT(*) AS total_checks,
		    COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
		    COUNT(*) FILTER (WHERE status = 'operational') AS operational_count,
		    COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count,
		    COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
		    COUNT(*) FILTER (WHERE status = 'error') AS error_count,
		    COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
		    COUNT(latency_ms) AS count_latency,
		    COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms,
		    COUNT(ping_latency_ms) AS count_ping_latency,
		    NOW()
		FROM channel_monitor_histories
		WHERE checked_at >= $1::date
		  AND checked_at < ($1::date + INTERVAL '1 day')
		GROUP BY monitor_id, model
		ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
		    total_checks = EXCLUDED.total_checks,
		    ok_count = EXCLUDED.ok_count,
		    operational_count = EXCLUDED.operational_count,
		    degraded_count = EXCLUDED.degraded_count,
		    failed_count = EXCLUDED.failed_count,
		    error_count = EXCLUDED.error_count,
		    sum_latency_ms = EXCLUDED.sum_latency_ms,
		    count_latency = EXCLUDED.count_latency,
		    sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
		    count_ping_latency = EXCLUDED.count_ping_latency,
		    computed_at = NOW()
	`
	res, err := r.db.ExecContext(ctx, q, targetDate)
	if err != nil {
		return 0, fmt.Errorf("upsert daily rollups for %s: %w", targetDate.Format("2006-01-02"), err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("rows affected (upsert rollups): %w", err)
	}
	return n, nil
}
// DeleteRollupsBefore physically deletes rollup rows with
// bucket_date < beforeDate, in the same batched fashion as
// DeleteHistoryBefore. Returns the cumulative deleted row count.
func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) {
	return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneRollupSQL, beforeDate)
}
// channelMonitorPruneBatchSize caps a single delete batch. It matches the
// 5000 used by ops_cleanup_service; small id-based batches on large tables
// avoid long transactions and WAL buildup.
const channelMonitorPruneBatchSize = 5000
// channelMonitorPruneHistorySQL deletes one batch of expired detail rows:
// select up to $2 ids with checked_at < $1, then delete by id.
const channelMonitorPruneHistorySQL = `
	WITH batch AS (
	    SELECT id FROM channel_monitor_histories
	    WHERE checked_at < $1
	    ORDER BY id
	    LIMIT $2
	)
	DELETE FROM channel_monitor_histories
	WHERE id IN (SELECT id FROM batch)
`
// channelMonitorPruneRollupSQL deletes one batch of expired rollup rows.
// bucket_date needs the ::date cast so the comparison matches the DATE column.
const channelMonitorPruneRollupSQL = `
	WITH batch AS (
	    SELECT id FROM channel_monitor_daily_rollups
	    WHERE bucket_date < $1::date
	    ORDER BY id
	    LIMIT $2
	)
	DELETE FROM channel_monitor_daily_rollups
	WHERE id IN (SELECT id FROM batch)
`
// deleteChannelMonitorBatched runs the batched DELETE in a loop until one
// batch affects zero rows, returning the cumulative number of deleted rows.
// The caller passes cutoff in whatever form matches the target column
// (the histories statement compares a TIMESTAMPTZ directly; the rollup
// statement casts with ::date on the SQL side).
func deleteChannelMonitorBatched(ctx context.Context, db *sql.DB, query string, cutoff time.Time) (int64, error) {
	var total int64
	for {
		res, err := db.ExecContext(ctx, query, cutoff, channelMonitorPruneBatchSize)
		if err != nil {
			return total, fmt.Errorf("channel_monitor prune batch: %w", err)
		}
		n, err := res.RowsAffected()
		if err != nil {
			return total, fmt.Errorf("channel_monitor prune rows affected: %w", err)
		}
		if n == 0 {
			return total, nil
		}
		total += n
	}
}
// LoadAggregationWatermark reads the single watermark row (id=1).
// The watermark table is not part of the ent schema (it only ever holds one
// row), so the read goes through raw SQL.
//   - A missing row or a NULL last_aggregated_date yields (nil, nil); the
//     caller decides the first-run backfill strategy.
func (r *channelMonitorRepository) LoadAggregationWatermark(ctx context.Context) (*time.Time, error) {
	const q = `SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1`
	var t sql.NullTime
	if err := r.db.QueryRowContext(ctx, q).Scan(&t); err != nil {
		// errors.Is also matches drivers/middleware that wrap sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, fmt.Errorf("load aggregation watermark: %w", err)
	}
	if !t.Valid {
		return nil, nil
	}
	return &t.Time, nil
}
// UpdateAggregationWatermark upserts the watermark row (id=1).
// $1::date lets PostgreSQL truncate the argument to a UTC date, matching the
// DATE type of the last_aggregated_date column.
func (r *channelMonitorRepository) UpdateAggregationWatermark(ctx context.Context, date time.Time) error {
	const q = `
		INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
		VALUES (1, $1::date, NOW())
		ON CONFLICT (id) DO UPDATE SET
		    last_aggregated_date = EXCLUDED.last_aggregated_date,
		    updated_at = NOW()
	`
	if _, err := r.db.ExecContext(ctx, q, date); err != nil {
		return fmt.Errorf("update aggregation watermark: %w", err)
	}
	return nil
}
// ---------- helpers ----------

// entToServiceMonitor maps an ent row to the service-layer struct. Nil
// slices/maps are normalized to empty values, and the API key is passed
// through still encrypted (the service layer decrypts). A nil row yields nil.
func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
	if row == nil {
		return nil
	}
	extras := row.ExtraModels
	if extras == nil {
		extras = []string{}
	}
	headers := row.ExtraHeaders
	if headers == nil {
		headers = map[string]string{}
	}
	out := &service.ChannelMonitor{
		ID:               row.ID,
		Name:             row.Name,
		Provider:         string(row.Provider),
		Endpoint:         row.Endpoint,
		APIKey:           row.APIKeyEncrypted, // still ciphertext; the service layer decrypts
		PrimaryModel:     row.PrimaryModel,
		ExtraModels:      extras,
		GroupName:        row.GroupName,
		Enabled:          row.Enabled,
		IntervalSeconds:  row.IntervalSeconds,
		LastCheckedAt:    row.LastCheckedAt,
		CreatedBy:        row.CreatedBy,
		CreatedAt:        row.CreatedAt,
		UpdatedAt:        row.UpdatedAt,
		ExtraHeaders:     headers,
		BodyOverrideMode: row.BodyOverrideMode,
		BodyOverride:     row.BodyOverride,
	}
	// Copy the pointee so the service struct does not alias the ent row.
	if row.TemplateID != nil {
		id := *row.TemplateID
		out.TemplateID = &id
	}
	return out
}
// emptyHeadersIfNilRepo mirrors service.emptyHeadersIfNil; the repo keeps its
// own copy to avoid an import cycle. A nil map becomes an empty (non-nil)
// one; any other map is returned as-is.
func emptyHeadersIfNilRepo(h map[string]string) map[string]string {
	if h != nil {
		return h
	}
	return map[string]string{}
}
// defaultBodyModeRepo normalizes an empty body-override mode to "off"
// (duplicated from the service layer for the same import-cycle reason).
func defaultBodyModeRepo(mode string) string {
	if mode != "" {
		return mode
	}
	return "off"
}
// emptySliceIfNil returns in unchanged unless it is nil, in which case an
// empty (non-nil) slice is returned.
func emptySliceIfNil(in []string) []string {
	if in != nil {
		return in
	}
	return []string{}
}
package repository
import (
"context"
"database/sql"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// channelMonitorRequestTemplateRepository implements
// service.ChannelMonitorRequestTemplateRepository. It lives in a separate
// file from channelMonitorRepository to keep responsibilities clear.
type channelMonitorRequestTemplateRepository struct {
	client *dbent.Client // ent client; a transactional client from ctx may take precedence
	db     *sql.DB       // raw connection, kept for parity with the monitor repository
}
// NewChannelMonitorRequestTemplateRepository creates the template repository.
func NewChannelMonitorRequestTemplateRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRequestTemplateRepository {
	return &channelMonitorRequestTemplateRepository{client: client, db: db}
}
// ---------- CRUD ----------

// Create persists a new request template and copies the DB-generated ID and
// timestamps back onto t. Honors a transactional client from ctx.
func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
	client := clientFromContext(ctx, r.client)
	builder := client.ChannelMonitorRequestTemplate.Create().
		SetName(t.Name).
		SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)).
		SetDescription(t.Description).
		SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
		SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
	// Optional override body: only set when provided.
	if t.BodyOverride != nil {
		builder = builder.SetBodyOverride(t.BodyOverride)
	}
	created, err := builder.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
	}
	t.ID = created.ID
	t.CreatedAt = created.CreatedAt
	t.UpdatedAt = created.UpdatedAt
	return nil
}
// GetByID loads one template by primary key.
//
// The read goes through clientFromContext for consistency with the write
// paths and ApplyToMonitors, so a lookup inside a transaction observes that
// transaction's uncommitted writes; without one the behavior is unchanged.
func (r *channelMonitorRequestTemplateRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitorRequestTemplate, error) {
	client := clientFromContext(ctx, r.client)
	row, err := client.ChannelMonitorRequestTemplate.Query().
		Where(channelmonitorrequesttemplate.IDEQ(id)).
		Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
	}
	return entToServiceTemplate(row), nil
}
// Update overwrites the template's mutable fields and copies the fresh
// UpdatedAt back onto t. A nil BodyOverride clears the column.
// NOTE(review): unlike Create, the provider is not part of the update set —
// presumably it is immutable after creation; confirm with the service layer.
func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
	client := clientFromContext(ctx, r.client)
	updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID).
		SetName(t.Name).
		SetDescription(t.Description).
		SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
		SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
	if t.BodyOverride != nil {
		updater = updater.SetBodyOverride(t.BodyOverride)
	} else {
		updater = updater.ClearBodyOverride()
	}
	updated, err := updater.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
	}
	t.UpdatedAt = updated.UpdatedAt
	return nil
}
// Delete removes the template with the given ID; a missing row maps to
// ErrChannelMonitorTemplateNotFound. Honors a transactional client from ctx.
func (r *channelMonitorRequestTemplateRepository) Delete(ctx context.Context, id int64) error {
	client := clientFromContext(ctx, r.client)
	err := client.ChannelMonitorRequestTemplate.DeleteOneID(id).Exec(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
	}
	return nil
}
// List returns all templates, optionally filtered by provider, ordered by
// (provider, name) for stable display. No pagination is applied.
func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, params service.ChannelMonitorRequestTemplateListParams) ([]*service.ChannelMonitorRequestTemplate, error) {
	q := r.client.ChannelMonitorRequestTemplate.Query()
	if params.Provider != "" {
		q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider)))
	}
	rows, err := q.
		Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("list monitor templates: %w", err)
	}
	out := make([]*service.ChannelMonitorRequestTemplate, 0, len(rows))
	for _, row := range rows {
		out = append(out, entToServiceTemplate(row))
	}
	return out, nil
}
// ApplyToMonitors copies the template's current configuration onto the
// monitors in monitorIDs and returns the number of rows updated.
// The WHERE clause filters twice — template_id = id AND id IN (monitorIDs) —
// so ids the caller passes that are NOT linked to this template are never
// overwritten. Uses ent UpdateMany so hooks still run.
func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
	if len(monitorIDs) == 0 {
		return 0, nil
	}
	client := clientFromContext(ctx, r.client)
	tpl, err := client.ChannelMonitorRequestTemplate.Query().
		Where(channelmonitorrequesttemplate.IDEQ(id)).
		Only(ctx)
	if err != nil {
		return 0, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
	}
	updater := client.ChannelMonitor.Update().
		Where(
			channelmonitor.TemplateIDEQ(id),
			channelmonitor.IDIn(monitorIDs...),
		).
		SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
		SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
	// Mirror the template's nil/non-nil override onto the monitors.
	if tpl.BodyOverride != nil {
		updater = updater.SetBodyOverride(tpl.BodyOverride)
	} else {
		updater = updater.ClearBodyOverride()
	}
	affected, err := updater.Save(ctx)
	if err != nil {
		return 0, fmt.Errorf("apply template to monitors: %w", err)
	}
	return int64(affected), nil
}
// CountAssociatedMonitors returns how many monitors reference the template
// (the UI shows this as "N configurations").
func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
	n, err := r.client.ChannelMonitor.Query().
		Where(channelmonitor.TemplateIDEQ(id)).
		Count(ctx)
	if err != nil {
		return 0, fmt.Errorf("count monitors for template %d: %w", id, err)
	}
	return int64(n), nil
}
// ListAssociatedMonitors returns brief fields of every monitor linked to the
// template, ordered by name for stable frontend display.
func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx context.Context, id int64) ([]*service.AssociatedMonitorBrief, error) {
	rows, err := r.client.ChannelMonitor.Query().
		Where(channelmonitor.TemplateIDEQ(id)).
		Order(dbent.Asc(channelmonitor.FieldName)).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("list associated monitors for template %d: %w", id, err)
	}
	out := make([]*service.AssociatedMonitorBrief, 0, len(rows))
	for _, row := range rows {
		out = append(out, &service.AssociatedMonitorBrief{
			ID:       row.ID,
			Name:     row.Name,
			Provider: string(row.Provider),
			Enabled:  row.Enabled,
		})
	}
	return out, nil
}
// ---------- helpers ----------

// entToServiceTemplate converts an ent row into the service-layer template
// model. A nil row yields nil; a nil ExtraHeaders map is normalized to an
// empty map so callers never observe nil.
func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate {
	if row == nil {
		return nil
	}
	out := &service.ChannelMonitorRequestTemplate{
		ID:               row.ID,
		Name:             row.Name,
		Provider:         string(row.Provider),
		Description:      row.Description,
		ExtraHeaders:     row.ExtraHeaders,
		BodyOverrideMode: row.BodyOverrideMode,
		BodyOverride:     row.BodyOverride,
		CreatedAt:        row.CreatedAt,
		UpdatedAt:        row.UpdatedAt,
	}
	if out.ExtraHeaders == nil {
		out.ExtraHeaders = map[string]string{}
	}
	return out
}
......@@ -89,6 +89,8 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
NewChannelMonitorRepository,
NewChannelMonitorRequestTemplateRepository,
// Cache implementations
NewGatewayCache,
......
......@@ -771,6 +771,9 @@ func TestAPIContracts(t *testing.T) {
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
......@@ -943,6 +946,9 @@ func TestAPIContracts(t *testing.T) {
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
......
......@@ -88,6 +88,9 @@ func RegisterAdminRoutes(
// 渠道管理
registerChannelRoutes(admin, h)
// 渠道监控
registerChannelMonitorRoutes(admin, h)
}
}
......@@ -567,3 +570,27 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}
// registerChannelMonitorRoutes wires the admin channel-monitor CRUD/run
// routes and the monitor request-template routes under the admin group.
func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
	monitorGroup := admin.Group("/channel-monitors")
	monitorGroup.GET("", h.Admin.ChannelMonitor.List)
	monitorGroup.POST("", h.Admin.ChannelMonitor.Create)
	monitorGroup.GET("/:id", h.Admin.ChannelMonitor.Get)
	monitorGroup.PUT("/:id", h.Admin.ChannelMonitor.Update)
	monitorGroup.DELETE("/:id", h.Admin.ChannelMonitor.Delete)
	monitorGroup.POST("/:id/run", h.Admin.ChannelMonitor.Run)
	monitorGroup.GET("/:id/history", h.Admin.ChannelMonitor.History)

	templateGroup := admin.Group("/channel-monitor-templates")
	templateGroup.GET("", h.Admin.ChannelMonitorTemplate.List)
	templateGroup.POST("", h.Admin.ChannelMonitorTemplate.Create)
	templateGroup.GET("/:id", h.Admin.ChannelMonitorTemplate.Get)
	templateGroup.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update)
	templateGroup.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete)
	templateGroup.GET("/:id/monitors", h.Admin.ChannelMonitorTemplate.AssociatedMonitors)
	templateGroup.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
}
......@@ -68,6 +68,12 @@ func RegisterUserRoutes(
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
// 用户可用渠道(非管理员接口)
channels := authenticated.Group("/channels")
{
channels.GET("/available", h.AvailableChannel.List)
}
// 使用记录
usage := authenticated.Group("/usage")
{
......@@ -103,5 +109,12 @@ func RegisterUserRoutes(
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
// 渠道监控(用户只读)
monitors := authenticated.Group("/channel-monitors")
{
monitors.GET("", h.ChannelMonitor.List)
monitors.GET("/:id/status", h.ChannelMonitor.GetStatus)
}
}
}
......@@ -111,6 +111,18 @@ func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
// normalizeBillingModelSource back-fills the default (ChannelMapped) when
// BillingModelSource is empty. Centralizing the default on *Channel means the
// service layer calls it once when a Channel enters memory (cache fill, repo
// read) and downstream read paths need no repeated fallback logic.
// Safe on a nil receiver.
func (c *Channel) normalizeBillingModelSource() {
	if c != nil && c.BillingModelSource == "" {
		c.BillingModelSource = BillingModelSourceChannelMapped
	}
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
......@@ -345,3 +357,209 @@ type ChannelUsageFields struct {
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
// SupportedModel is one supported-model entry of a channel (wildcard-free,
// safe to display to users directly).
type SupportedModel struct {
	Name     string               // user-facing model name
	Platform string               // owning platform
	Pricing  *ChannelModelPricing // pricing details (nil = no pricing configured)
}
// wildcardSuffix marks a trailing wildcard in a model pattern (only suffix
// matching is supported).
const wildcardSuffix = "*"

// splitWildcardSuffix splits a model pattern into (prefix, isWildcard).
//
//	"claude-opus-*" → ("claude-opus-", true)
//	"claude-opus-4" → ("claude-opus-4", false)
//	"*"             → ("", true)
//
// The returned prefix keeps its original casing; callers lowercase as needed.
func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
	if !strings.HasSuffix(pattern, wildcardSuffix) {
		return pattern, false
	}
	return pattern[:len(pattern)-len(wildcardSuffix)], true
}
// GetModelPricingByPlatform looks up pricing for an exact model under the
// given platform; returns nil when not found. Unlike GetModelPricing it is
// platform-scoped, so same-named models on other platforms never match.
// Matching is case-insensitive; the result is a copy so cached channels are
// never mutated.
func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing {
	if c == nil {
		return nil
	}
	want := strings.ToLower(model)
	for i := range c.ModelPricing {
		entry := &c.ModelPricing[i]
		if entry.Platform != platform {
			continue
		}
		for _, name := range entry.Models {
			if strings.ToLower(name) == want {
				cp := entry.Clone()
				return &cp
			}
		}
	}
	return nil
}
// platformPricingIndex is a composite index of pricing info for a single
// platform. One scan supports both exact lookup (the byLower branch) and
// ordered traversal (the names branch), so SupportedModels does not rescan
// the pricing list once per platform.
//
// byLower and names/originalCase share one dedup rule: keyed by the
// lowercased model name, the first hit keeps its original casing. names
// preserves the pricing-row scan order for stable iteration.
type platformPricingIndex struct {
	byLower      map[string]*ChannelModelPricing // lowercased model name → pricing (Clone'd)
	originalCase map[string]string               // lowercased model name → original-case model name
	names        []string                        // priced model names in their ORIGINAL case, insertion-ordered, deduped case-insensitively (first wins)
}
// buildPricingIndex scans the channel's pricing list once and aggregates it
// into per-platform lookup indexes. Index values are Clone'd pointers, so
// callers may hand out copies without polluting the cache. Wildcard-suffix
// entries (e.g. "claude-*") are not indexed — they are patterns, not
// concrete model names. Within a platform, names dedupe case-insensitively
// and the first occurrence keeps its original casing.
func buildPricingIndex(pricings []ChannelModelPricing) map[string]*platformPricingIndex {
	byPlatform := make(map[string]*platformPricingIndex)
	for i := range pricings {
		platform := pricings[i].Platform
		pidx := byPlatform[platform]
		if pidx == nil {
			pidx = &platformPricingIndex{
				byLower:      make(map[string]*ChannelModelPricing),
				originalCase: make(map[string]string),
				names:        make([]string, 0),
			}
			byPlatform[platform] = pidx
		}
		for _, name := range pricings[i].Models {
			if _, isWild := splitWildcardSuffix(name); isWild {
				continue
			}
			key := strings.ToLower(name)
			if _, dup := pidx.byLower[key]; dup {
				continue // first hit wins: first pricing and first original casing after dedup
			}
			cloned := pricings[i].Clone()
			pidx.byLower[key] = &cloned
			pidx.originalCase[key] = name
			pidx.names = append(pidx.names, name)
		}
	}
	return byPlatform
}
// SupportedModels computes the channel's supported-model list; the result is
// guaranteed wildcard-free.
//
// Algorithm (union of mapping ∪ pricing):
//
//   - Pass A (mapping): walk ModelMapping
//   - exact src → target: display name = src (the user's view); pricing is
//     looked up by target within the same platform's pricing (after the
//     mapping rewrite, target is what's actually billed — the user-perceived
//     "real cost"). Falls back to looking up by src when target is empty or
//     a wildcard.
//   - wildcard src (e.g. "claude-3-*"): expand using prefix-matched models
//     from the same platform's pricing as candidates; each candidate is
//     priced by itself (wildcard mappings are usually passthrough, with a
//     wildcard target too).
//   - a lone "*" mapping key takes the wildcard branch (empty prefix →
//     expand everything).
//   - Pass B (pricing-only): walk every non-wildcard model in ModelPricing
//     and add those Pass A did not — display name = the priced model name,
//     pricing = itself (the key fix: having pricing means the channel
//     supports the model even without a mapping).
//
// When a display name hits pricing, the pricing's ORIGINAL casing is used
// (pricing is the source of truth for model identity). Results are stably
// sorted by (Platform, Name) and deduped by (Platform, lowercase(Name)),
// first writer wins.
//
// Note: pricing lookup stays within channel.ModelPricing — the global
// LiteLLM fallback is layered on by the caller
// (`ChannelService.ListAvailable`) when composing display data.
func (c *Channel) SupportedModels() []SupportedModel {
	if c == nil {
		return nil
	}
	if len(c.ModelMapping) == 0 && len(c.ModelPricing) == 0 {
		return nil
	}
	idx := buildPricingIndex(c.ModelPricing)
	type dedupKey struct {
		platform string
		name     string
	}
	seen := make(map[dedupKey]struct{})
	result := make([]SupportedModel, 0)
	// lookup finds pricing by exact name in a platform pricing index; on a
	// hit it returns the pricing's original casing as the display name.
	lookup := func(pidx *platformPricingIndex, name string) (display string, pricing *ChannelModelPricing) {
		if pidx == nil || name == "" {
			return name, nil
		}
		lower := strings.ToLower(name)
		if p, ok := pidx.byLower[lower]; ok {
			return pidx.originalCase[lower], p
		}
		return name, nil
	}
	add := func(platform, displayName string, pricing *ChannelModelPricing) {
		key := dedupKey{platform: platform, name: strings.ToLower(displayName)}
		if _, ok := seen[key]; ok {
			return
		}
		seen[key] = struct{}{}
		result = append(result, SupportedModel{
			Name:     displayName,
			Platform: platform,
			Pricing:  pricing,
		})
	}
	// Pass A: expand from the mapping.
	for platform, mapping := range c.ModelMapping {
		if len(mapping) == 0 {
			continue
		}
		pidx := idx[platform]
		for src, target := range mapping {
			prefix, isWild := splitWildcardSuffix(src)
			if isWild {
				if pidx == nil {
					continue
				}
				prefixLower := strings.ToLower(prefix)
				for _, candidate := range pidx.names {
					if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
						display, pricing := lookup(pidx, candidate)
						add(platform, display, pricing)
					}
				}
				continue
			}
			// Exact mapping: price by target; fall back to src when target is
			// missing or itself a wildcard.
			pricingKey := target
			if pricingKey == "" {
				pricingKey = src
			}
			if _, targetWild := splitWildcardSuffix(pricingKey); targetWild {
				pricingKey = src
			}
			_, pricing := lookup(pidx, pricingKey)
			// Display name prefers src's original casing from pricing (when
			// src itself is a priced model name).
			displayName, _ := lookup(pidx, src)
			add(platform, displayName, pricing)
		}
	}
	// Pass B: backfill concrete priced models the mapping did not cover
	// (fixes "pricing exists but no mapping → model not shown").
	for platform, pidx := range idx {
		for _, name := range pidx.names {
			display, pricing := lookup(pidx, name)
			add(platform, display, pricing)
		}
	}
	sort.SliceStable(result, func(i, j int) bool {
		if result[i].Platform != result[j].Platform {
			return result[i].Platform < result[j].Platform
		}
		return result[i].Name < result[j].Name
	})
	return result
}
package service
import (
"context"
"fmt"
"sort"
"strings"
)
// AvailableGroupRef is the brief info of an associated group in the channel
// view.
//
// The user-side "available channels" page renders from this: exclusive vs
// public group (IsExclusive), subscription vs standard (SubscriptionType),
// and the default rate multiplier (RateMultiplier). Per-user rate overrides
// are deliberately not exposed here — the frontend fetches them itself via
// /groups/rates, consistent with the API key page.
type AvailableGroupRef struct {
	ID               int64
	Name             string
	Platform         string
	SubscriptionType string
	RateMultiplier   float64
	IsExclusive      bool
}
// AvailableChannel is the available-channel view: channel basics + associated
// groups + the derived supported-model list (wildcard-free), rendered on the
// "available channels" page.
type AvailableChannel struct {
	ID                 int64
	Name               string
	Description        string
	Status             string
	BillingModelSource string
	RestrictModels     bool
	Groups             []AvailableGroupRef
	SupportedModels    []SupportedModel
}
// ListAvailable returns the available-channel view for every channel: each
// channel carries its associated group info and derived supported models.
//
// Supported models come from (*Channel).SupportedModels() (mapping ∪ pricing
// union). For models the channel has no pricing for, a display-only pricing
// is synthesized from the PricingService's global LiteLLM data so users see
// a default price instead of "not configured".
//
// Group info is resolved by ID against groupRepo.ListActive; IDs in a
// channel's GroupIDs that are absent from the active list (disabled or
// deleted groups) are dropped.
//
// Precondition: s.groupRepo must be non-nil (guaranteed by wire DI). A nil
// deref here is intentional fail-fast rather than silently masking a missing
// injection.
func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
	channels, err := s.repo.ListAll(ctx)
	if err != nil {
		return nil, fmt.Errorf("list channels: %w", err)
	}
	activeGroups, err := s.groupRepo.ListActive(ctx)
	if err != nil {
		return nil, fmt.Errorf("list active groups: %w", err)
	}
	refByID := make(map[int64]AvailableGroupRef, len(activeGroups))
	for i := range activeGroups {
		g := activeGroups[i]
		refByID[g.ID] = AvailableGroupRef{
			ID:               g.ID,
			Name:             g.Name,
			Platform:         g.Platform,
			SubscriptionType: g.SubscriptionType,
			RateMultiplier:   g.RateMultiplier,
			IsExclusive:      g.IsExclusive,
		}
	}
	views := make([]AvailableChannel, 0, len(channels))
	for i := range channels {
		ch := &channels[i]
		refs := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
		for _, gid := range ch.GroupIDs {
			if ref, ok := refByID[gid]; ok {
				refs = append(refs, ref)
			}
		}
		sort.SliceStable(refs, func(a, b int) bool { return refs[a].Name < refs[b].Name })
		ch.normalizeBillingModelSource()
		supported := ch.SupportedModels()
		s.fillGlobalPricingFallback(supported)
		views = append(views, AvailableChannel{
			ID:                 ch.ID,
			Name:               ch.Name,
			Description:        ch.Description,
			Status:             ch.Status,
			BillingModelSource: ch.BillingModelSource,
			RestrictModels:     ch.RestrictModels,
			Groups:             refs,
			SupportedModels:    supported,
		})
	}
	sort.SliceStable(views, func(a, b int) bool {
		return strings.ToLower(views[a].Name) < strings.ToLower(views[b].Name)
	})
	return views, nil
}
// fillGlobalPricingFallback synthesizes a display-only, token-billed pricing
// from global LiteLLM data for supported models that missed channel pricing.
// Used only by the "available channels" view — never by the real billing
// path.
//
// Skipped entirely when s.pricingService is nil (test scenarios).
func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) {
	if s.pricingService == nil {
		return
	}
	for i := range models {
		entry := &models[i]
		if entry.Pricing != nil {
			continue
		}
		if lp := s.pricingService.GetModelPricing(entry.Name); lp != nil {
			entry.Pricing = synthesizePricingFromLiteLLM(lp)
		}
	}
}
// synthesizePricingFromLiteLLM converts LiteLLM pricing data into
// ChannelModelPricing form, for display only. BillingMode is fixed to token;
// the image case's OutputCostPerImageToken also lands in ImageOutputPrice
// (matching the channel-side "image output billed per token" semantics).
//
// A zero LiteLLM field means "not configured" and is omitted from display.
func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing) *ChannelModelPricing {
	if lp == nil {
		return nil
	}
	out := &ChannelModelPricing{BillingMode: BillingModeToken}
	out.InputPrice = nonZeroPtr(lp.InputCostPerToken)
	out.OutputPrice = nonZeroPtr(lp.OutputCostPerToken)
	out.CacheWritePrice = nonZeroPtr(lp.CacheCreationInputTokenCost)
	out.CacheReadPrice = nonZeroPtr(lp.CacheReadInputTokenCost)
	out.ImageOutputPrice = nonZeroPtr(lp.OutputCostPerImageToken)
	return out
}
// nonZeroPtr returns a pointer to v, or nil when v is zero (zero means "not
// configured" in LiteLLM data).
func nonZeroPtr(v float64) *float64 {
	if v != 0 {
		return &v
	}
	return nil
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// stubGroupRepoForAvailable is a GroupRepository stub for the ListAvailable
// tests; only ListActive matters. The other methods are irrelevant to these
// tests and return zero values.
// When listActiveErr is non-nil, ListActive returns it (error-propagation
// tests). listActiveCalls counts invocations, used to assert behaviors such
// as "on short-circuit failure, groupRepo is never touched".
type stubGroupRepoForAvailable struct {
	activeGroups    []Group
	listActiveErr   error
	listActiveCalls int
}
// ListActive returns the configured active groups, or listActiveErr when
// set; every call increments listActiveCalls.
func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
	s.listActiveCalls++
	if err := s.listActiveErr; err != nil {
		return nil, err
	}
	return s.activeGroups, nil
}
// The remaining GroupRepository methods are interface-satisfying no-ops;
// they are never exercised by these tests and all return zero values.
func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil }
func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) {
	return nil, nil
}
func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
	return nil, nil
}
func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil }
func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil }
func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
	return nil, nil
}
func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
	return nil, nil
}
func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) {
	return false, nil
}
func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
	return 0, 0, nil
}
func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
	return 0, nil
}
func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
	return nil, nil
}
func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
	return nil
}
func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
	return nil
}
// newAvailableChannelService builds a ChannelService whose
// channelRepo.ListAll returns the given channels, with the groupRepo decided
// by the caller. Passing an empty stub means "the active group list is
// empty".
func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
	repo := &mockChannelRepository{
		listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil },
	}
	return NewChannelService(repo, groupRepo, nil, nil)
}
func TestListAvailable_EmptyActiveGroups_NoGroupsAttached(t *testing.T) {
	// With an empty active-group list, the channel's Groups must come back
	// empty without an error.
	channels := []Channel{{
		ID:       1,
		Name:     "chA",
		Status:   StatusActive,
		GroupIDs: []int64{10, 20},
	}}
	svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
	out, err := svc.ListAvailable(context.Background())
	require.NoError(t, err)
	require.Len(t, out, 1)
	require.Empty(t, out[0].Groups)
}
func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
	// A group referenced by the channel's GroupIDs but missing from the
	// ListActive result (disabled or deleted) must be silently dropped.
	channels := []Channel{{
		ID:       1,
		Name:     "chA",
		Status:   StatusActive,
		GroupIDs: []int64{1, 99},
	}}
	groupRepo := &stubGroupRepoForAvailable{
		activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
	}
	svc := newAvailableChannelService(channels, groupRepo)
	out, err := svc.ListAvailable(context.Background())
	require.NoError(t, err)
	require.Len(t, out, 1)
	require.Len(t, out[0].Groups, 1)
	require.Equal(t, int64(1), out[0].Groups[0].ID)
}
// Output is sorted case-insensitively by channel name.
func TestListAvailable_SortedByName(t *testing.T) {
	channels := []Channel{
		{ID: 1, Name: "beta"},
		{ID: 2, Name: "Alpha"},
		{ID: 3, Name: "charlie"},
	}
	svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
	out, err := svc.ListAvailable(context.Background())
	require.NoError(t, err)
	require.Len(t, out, 3)
	require.Equal(t, "Alpha", out[0].Name)
	require.Equal(t, "beta", out[1].Name)
	require.Equal(t, "charlie", out[2].Name)
}
func TestListAvailable_ListAllErrorPropagates(t *testing.T) {
	// When ListAll fails, ListAvailable must return the wrapped error
	// directly and must not touch groupRepo (short-circuit).
	sentinel := errors.New("list-all-boom")
	repo := &mockChannelRepository{
		listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel },
	}
	groupRepo := &stubGroupRepoForAvailable{}
	svc := NewChannelService(repo, groupRepo, nil, nil)
	out, err := svc.ListAvailable(context.Background())
	require.Nil(t, out)
	require.ErrorIs(t, err, sentinel)
	require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v")
	require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive")
}
func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
	// When groupRepo.ListActive fails, ListAvailable must return the wrapped
	// error directly.
	sentinel := errors.New("list-active-boom")
	svc := newAvailableChannelService(
		[]Channel{{ID: 1, Name: "chA"}},
		&stubGroupRepoForAvailable{listActiveErr: sentinel},
	)
	out, err := svc.ListAvailable(context.Background())
	require.Nil(t, out)
	require.ErrorIs(t, err, sentinel)
	require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v")
}
func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
	// An empty BillingModelSource must be back-filled with
	// BillingModelSourceChannelMapped while explicit values are preserved
	// (handled once in the service layer so handlers need not repeat the
	// default logic).
	channels := []Channel{
		{ID: 1, Name: "empty", BillingModelSource: ""},
		{ID: 2, Name: "explicit", BillingModelSource: BillingModelSourceUpstream},
	}
	svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
	out, err := svc.ListAvailable(context.Background())
	require.NoError(t, err)
	require.Len(t, out, 2)
	// Look up by Name to avoid depending on sort side effects.
	byName := make(map[string]string, len(out))
	for _, ch := range out {
		byName[ch.Name] = ch.BillingModelSource
	}
	require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
	require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
}
package service
import (
"context"
"fmt"
"log/slog"
)
// Channel-monitor aggregation layer: composes latest + availability into the
// summary / detail shapes the admin/user views need.
// Every method follows the "log failures, return zero values" principle so a
// failed N+1 query never breaks list rendering.
// BatchMonitorStatusSummary aggregates latest + 7d availability for multiple
// monitors in bulk (used by admin/user lists to eliminate N+1 queries).
// On failure it returns an empty map; errors are only logged so list
// rendering is never affected.
//
// Parameters:
//   - ids: monitor IDs to aggregate
//   - primaryByID: monitor ID -> primary model (drives the 7d availability
//     and latest-status reads)
//   - extrasByID: monitor ID -> extra model list (drives the latest-status
//     reads that fill ExtraModels)
func (s *ChannelMonitorService) BatchMonitorStatusSummary(
	ctx context.Context,
	ids []int64,
	primaryByID map[int64]string,
	extrasByID map[int64][]string,
) map[int64]MonitorStatusSummary {
	out := make(map[int64]MonitorStatusSummary, len(ids))
	if len(ids) == 0 {
		return out
	}
	latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids)
	if err != nil {
		slog.Warn("channel_monitor: batch load latest failed", "error", err)
		latestMap = map[int64][]*ChannelMonitorLatest{}
	}
	availMap, err := s.repo.ComputeAvailabilityForMonitors(ctx, ids, monitorAvailability7Days)
	if err != nil {
		slog.Warn("channel_monitor: batch compute availability failed", "error", err)
		availMap = map[int64][]*ChannelMonitorAvailability{}
	}
	for _, id := range ids {
		out[id] = buildStatusSummary(
			indexLatestByModel(latestMap[id]),
			indexAvailabilityByModel(availMap[id]),
			primaryByID[id],
			extrasByID[id],
		)
	}
	return out
}
// ListUserView is the user read-only view: an overview of all enabled
// monitors. The batch aggregation APIs avoid N+1:
//
//	1 query for monitors;
//	1 batch query for latest (incl. ping_latency_ms);
//	1 batch query for 7d availability;
//	1 batch query for timelines (most recent N rows of each primary model).
func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonitorView, error) {
	monitors, err := s.repo.ListEnabled(ctx)
	if err != nil {
		return nil, fmt.Errorf("list enabled monitors: %w", err)
	}
	if len(monitors) == 0 {
		return []*UserMonitorView{}, nil
	}
	ids, primaryByID, extrasByID := collectMonitorIndexes(monitors)
	summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID)
	latestMap := s.batchLatest(ctx, ids)
	timelineMap := s.batchTimeline(ctx, ids, primaryByID)
	views := make([]*UserMonitorView, 0, len(monitors))
	for _, m := range monitors {
		primaryLatest := pickLatest(latestMap[m.ID], m.PrimaryModel)
		views = append(views, buildUserViewFromSummary(m, summaries[m.ID], primaryLatest, timelineMap[m.ID]))
	}
	return views, nil
}
// collectMonitorIndexes expands a monitor list into the three index
// structures the aggregation queries need: the ID slice, primary model by
// ID, and extra models by ID.
func collectMonitorIndexes(monitors []*ChannelMonitor) ([]int64, map[int64]string, map[int64][]string) {
	n := len(monitors)
	ids := make([]int64, 0, n)
	primaryByID := make(map[int64]string, n)
	extrasByID := make(map[int64][]string, n)
	for _, mon := range monitors {
		ids = append(ids, mon.ID)
		primaryByID[mon.ID] = mon.PrimaryModel
		extrasByID[mon.ID] = mon.ExtraModels
	}
	return ids, primaryByID, extrasByID
}
// batchLatest bulk-loads the latest rows per model; failures are logged only
// (same policy as BatchMonitorStatusSummary — never block list rendering).
func (s *ChannelMonitorService) batchLatest(ctx context.Context, ids []int64) map[int64][]*ChannelMonitorLatest {
	rows, err := s.repo.ListLatestForMonitorIDs(ctx, ids)
	if err == nil {
		return rows
	}
	slog.Warn("channel_monitor: user view batch latest failed", "error", err)
	return map[int64][]*ChannelMonitorLatest{}
}
// batchTimeline bulk-loads each monitor's most recent
// monitorTimelineMaxPoints history rows for its primary model; failures are
// logged only.
func (s *ChannelMonitorService) batchTimeline(
	ctx context.Context,
	ids []int64,
	primaryByID map[int64]string,
) map[int64][]*ChannelMonitorHistoryEntry {
	rows, err := s.repo.ListRecentHistoryForMonitors(ctx, ids, primaryByID, monitorTimelineMaxPoints)
	if err == nil {
		return rows
	}
	slog.Warn("channel_monitor: user view batch timeline failed", "error", err)
	return map[int64][]*ChannelMonitorHistoryEntry{}
}
// pickLatest selects the entry matching model from a latest slice; it
// returns nil when model is empty or no entry matches.
func pickLatest(rows []*ChannelMonitorLatest, model string) *ChannelMonitorLatest {
	if model == "" {
		return nil
	}
	for _, row := range rows {
		if row.Model == model {
			return row
		}
	}
	return nil
}
// GetUserDetail is the user read-only view: detail for a single monitor
// (7d/15d/30d availability and average latency per model). The api_key is
// never exposed; disabled monitors are hidden behind
// ErrChannelMonitorNotFound.
func (s *ChannelMonitorService) GetUserDetail(ctx context.Context, id int64) (*UserMonitorDetail, error) {
	m, err := s.repo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if !m.Enabled {
		return nil, ErrChannelMonitorNotFound
	}
	latest, err := s.repo.ListLatestPerModel(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("list latest per model: %w", err)
	}
	availMap, err := s.collectAvailabilityWindows(ctx, id)
	if err != nil {
		return nil, err
	}
	models := mergeModelDetails(m, latest, availMap)
	return &UserMonitorDetail{
		ID:        m.ID,
		Name:      m.Name,
		Provider:  m.Provider,
		GroupName: m.GroupName,
		Models:    models,
	}, nil
}
// collectAvailabilityWindows queries the 7/15/30-day windows in one pass and
// returns the rows organized per model under each window size.
func (s *ChannelMonitorService) collectAvailabilityWindows(ctx context.Context, monitorID int64) (map[int]map[string]*ChannelMonitorAvailability, error) {
	windows := []int{monitorAvailability7Days, monitorAvailability15Days, monitorAvailability30Days}
	byWindow := make(map[int]map[string]*ChannelMonitorAvailability, len(windows))
	for _, days := range windows {
		rows, err := s.repo.ComputeAvailability(ctx, monitorID, days)
		if err != nil {
			return nil, fmt.Errorf("compute availability %dd: %w", days, err)
		}
		byWindow[days] = indexAvailabilityByModel(rows)
	}
	return byWindow, nil
}
// ---------- pure helpers (no IO; shared by batch / single-monitor / detail paths) ----------

// indexLatestByModel indexes a latest slice by model name (small utility so
// the hot path does not repeat the map-building loop).
func indexLatestByModel(rows []*ChannelMonitorLatest) map[string]*ChannelMonitorLatest {
	byModel := make(map[string]*ChannelMonitorLatest, len(rows))
	for _, row := range rows {
		byModel[row.Model] = row
	}
	return byModel
}
// indexAvailabilityByModel indexes an availability slice by model name.
func indexAvailabilityByModel(rows []*ChannelMonitorAvailability) map[string]*ChannelMonitorAvailability {
	byModel := make(map[string]*ChannelMonitorAvailability, len(rows))
	for _, row := range rows {
		byModel[row.Model] = row
	}
	return byModel
}
// buildStatusSummary assembles a MonitorStatusSummary from latest +
// availability dictionaries. Pure assembly, no IO — reusable on both the
// batch and single-monitor paths.
func buildStatusSummary(
	latestByModel map[string]*ChannelMonitorLatest,
	availByModel map[string]*ChannelMonitorAvailability,
	primary string,
	extras []string,
) MonitorStatusSummary {
	out := MonitorStatusSummary{ExtraModels: make([]ExtraModelStatus, 0, len(extras))}
	if primary != "" {
		if latest, ok := latestByModel[primary]; ok {
			out.PrimaryStatus = latest.Status
			out.PrimaryLatencyMs = latest.LatencyMs
		}
		if avail, ok := availByModel[primary]; ok {
			out.Availability7d = avail.AvailabilityPct
		}
	}
	for _, model := range extras {
		status := ExtraModelStatus{Model: model}
		if latest, ok := latestByModel[model]; ok {
			status.Status = latest.Status
			status.LatencyMs = latest.LatencyMs
		}
		out.ExtraModels = append(out.ExtraModels, status)
	}
	return out
}
// buildUserViewFromSummary fills a UserMonitorView from a pre-aggregated
// MonitorStatusSummary + the primary model's latest row + timeline (no IO).
// primaryLatest may be nil (the monitor has no history yet);
// timelineEntries may be empty.
func buildUserViewFromSummary(
	m *ChannelMonitor,
	summary MonitorStatusSummary,
	primaryLatest *ChannelMonitorLatest,
	timelineEntries []*ChannelMonitorHistoryEntry,
) *UserMonitorView {
	out := &UserMonitorView{
		ID:               m.ID,
		Name:             m.Name,
		Provider:         m.Provider,
		GroupName:        m.GroupName,
		PrimaryModel:     m.PrimaryModel,
		PrimaryStatus:    summary.PrimaryStatus,
		PrimaryLatencyMs: summary.PrimaryLatencyMs,
		Availability7d:   summary.Availability7d,
		ExtraModels:      summary.ExtraModels,
		Timeline:         buildTimelinePoints(timelineEntries),
	}
	if primaryLatest == nil {
		return out
	}
	out.PrimaryPingLatencyMs = primaryLatest.PingLatencyMs
	return out
}
// buildTimelinePoints trims history entries down to timeline points
// (dropping message/ID/Model to shrink the response body).
func buildTimelinePoints(entries []*ChannelMonitorHistoryEntry) []UserMonitorTimelinePoint {
	points := make([]UserMonitorTimelinePoint, 0, len(entries))
	for _, entry := range entries {
		points = append(points, UserMonitorTimelinePoint{
			Status:        entry.Status,
			LatencyMs:     entry.LatencyMs,
			PingLatencyMs: entry.PingLatencyMs,
			CheckedAt:     entry.CheckedAt,
		})
	}
	return points
}
// mergeModelDetails merges latest + the three availability windows into a
// ModelDetail list covering the primary model followed by the extras.
// Reuses indexLatestByModel to avoid duplicating the map-building logic.
func mergeModelDetails(
	m *ChannelMonitor,
	latest []*ChannelMonitorLatest,
	availMap map[int]map[string]*ChannelMonitorAvailability,
) []ModelDetail {
	models := append([]string{m.PrimaryModel}, m.ExtraModels...)
	latestByModel := indexLatestByModel(latest)
	details := make([]ModelDetail, 0, len(models))
	for _, model := range models {
		detail := ModelDetail{Model: model}
		if row, ok := latestByModel[model]; ok {
			detail.LatestStatus = row.Status
			detail.LatestLatencyMs = row.LatencyMs
		}
		// Indexing a missing window map yields nil, and nil-map lookups are
		// safe, so each window check stays a one-liner.
		if a, ok := availMap[monitorAvailability7Days][model]; ok {
			detail.Availability7d = a.AvailabilityPct
			detail.AvgLatency7dMs = a.AvgLatencyMs
		}
		if a, ok := availMap[monitorAvailability15Days][model]; ok {
			detail.Availability15d = a.AvailabilityPct
		}
		if a, ok := availMap[monitorAvailability30Days][model]; ok {
			detail.Availability30d = a.AvailabilityPct
		}
		details = append(details, detail)
	}
	return details
}
package service
import (
"fmt"
"math/rand/v2"
"regexp"
"strconv"
)
// monitorChallengePromptTemplate is a 1:1 copy of the few-shot template from
// BingZi-233/check-cx. The verbs %d %s %d render as "<a> <op> <b>".
const monitorChallengePromptTemplate = `Calculate and respond with ONLY the number, nothing else.
Q: 3 + 5 = ?
A: 8
Q: 12 - 7 = ?
A: 5
Q: %d %s %d = ?
A:`
// monitorChallengeNumberRegex extracts every integer (including a leading
// minus sign) from a response.
var monitorChallengeNumberRegex = regexp.MustCompile(`-?\d+`)
// monitorChallenge is one challenge's prompt plus its expected answer.
type monitorChallenge struct {
	Prompt   string // rendered few-shot prompt to send upstream
	Expected string // decimal answer the response must contain
}
// generateChallenge produces a random arithmetic challenge:
//   - two random integers in [monitorChallengeMin, monitorChallengeMax]
//   - 50% addition / 50% subtraction; subtraction uses max - min so the
//     answer is never negative
//   - renders the few-shot template
//
// Cryptographic randomness is not required: math/rand/v2 spreads well enough
// and avoids the overhead of crypto/rand.
func generateChallenge() monitorChallenge {
	x := randIntInRange(monitorChallengeMin, monitorChallengeMax)
	y := randIntInRange(monitorChallengeMin, monitorChallengeMax)
	if rand.IntN(2) != 0 { //nolint:gosec // test-question generation only, no security impact
		// Subtraction, kept non-negative by ordering the operands.
		big, small := x, y
		if big < small {
			big, small = small, big
		}
		return monitorChallenge{
			Prompt:   fmt.Sprintf(monitorChallengePromptTemplate, big, "-", small),
			Expected: strconv.Itoa(big - small),
		}
	}
	// Addition.
	return monitorChallenge{
		Prompt:   fmt.Sprintf(monitorChallengePromptTemplate, x, "+", y),
		Expected: strconv.Itoa(x + y),
	}
}
// randIntInRange 返回 [min, max] 闭区间的随机整数。
func randIntInRange(minVal, maxVal int) int {
if maxVal <= minVal {
return minVal
}
return minVal + rand.IntN(maxVal-minVal+1) //nolint:gosec
}
// validateChallenge reports whether the expected integer answer appears
// anywhere in the response text. Empty inputs never validate.
func validateChallenge(responseText, expected string) bool {
	if responseText == "" || expected == "" {
		return false
	}
	for _, candidate := range monitorChallengeNumberRegex.FindAllString(responseText, -1) {
		if candidate == expected {
			return true
		}
	}
	return false
}
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/tidwall/gjson"
)
// monitorHTTPClient is one shared http.Client so each check avoids
// rebuilding a transport. Its custom Transport re-validates the IP at dial
// time, preventing DNS rebinding from bypassing validateEndpoint.
var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout)

// monitorPingHTTPClient serves the HEAD ping against the endpoint origin,
// with a shorter timeout.
var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout)
// newSSRFSafeHTTPClient returns an http.Client wired to safeDialContext.
// For the monitor module's outbound requests only — every target is expected
// to be a public endpoint.
func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
	tr := &http.Transport{
		DialContext:           safeDialContext,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          16,
		IdleConnTimeout:       monitorIdleConnTimeout,
		TLSHandshakeTimeout:   monitorTLSHandshakeTimeout,
		ResponseHeaderTimeout: monitorResponseHeaderTimeout,
	}
	return &http.Client{Timeout: timeout, Transport: tr}
}
// CheckOptions carries the per-check customization inputs.
// Every field is optional (zero values are equivalent to "default behavior").
type CheckOptions struct {
	// ExtraHeaders holds user-defined HTTP headers (merged over the adapter's
	// default headers; user values win).
	ExtraHeaders map[string]string
	// BodyOverrideMode: off | merge | replace
	BodyOverrideMode string
	// BodyOverride is shallow-merged in merge mode (keys hitting the deny list
	// are silently dropped) and used verbatim as the full body in replace mode.
	BodyOverride map[string]any
}
// runCheckForModel performs one full probe of a single (provider, model)
// pair. It never returns an error: every failure mode is folded into the
// returned CheckResult (Status=error/failed).
//
// opts carries customization from templates / monitor snapshots; nil behaves
// like "mode off with no extra headers".
func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string, opts *CheckOptions) *CheckResult {
	result := &CheckResult{
		Model:     model,
		Status:    MonitorStatusError,
		CheckedAt: time.Now(),
	}
	challenge := generateChallenge()
	mode := bodyOverrideMode(opts)
	started := time.Now()
	respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt, opts)
	elapsed := time.Since(started)
	elapsedMs := int(elapsed / time.Millisecond)
	result.LatencyMs = &elapsedMs
	if err != nil {
		result.Status = MonitorStatusError
		result.Message = truncateMessage(sanitizeErrorMessage(err.Error()))
		return result
	}
	if statusCode < 200 || statusCode >= 300 {
		// Error path: report rawBody rather than respText — the gjson textPath
		// extraction is usually empty on error responses and would drop the real
		// upstream message (e.g. `{"error":{"message":"No available accounts ..."}}`).
		result.Status = MonitorStatusError
		snippet := truncateForErrorBody(rawBody)
		result.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, snippet)))
		return result
	}
	// Replace mode: the user body is static, so the challenge cannot be
	// embedded. Skip challenge validation and treat "HTTP 2xx + non-empty
	// extracted text (adapter.textPath)" as operational; empty text downgrades
	// to failed (upstream returned 200 without real content).
	if mode == MonitorBodyOverrideModeReplace {
		if strings.TrimSpace(respText) == "" {
			result.Status = MonitorStatusFailed
			result.Message = truncateMessage("replace-mode: upstream returned 2xx with empty text")
			return result
		}
		return finalizeOperationalOrDegraded(result, elapsed, elapsedMs)
	}
	if !validateChallenge(respText, challenge.Expected) {
		result.Status = MonitorStatusFailed
		result.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText)))
		return result
	}
	return finalizeOperationalOrDegraded(result, elapsed, elapsedMs)
}
// finalizeOperationalOrDegraded performs the final operational/degraded
// decision. Split out so runCheckForModel stays under 30 lines.
func finalizeOperationalOrDegraded(res *CheckResult, latency time.Duration, latencyMs int) *CheckResult {
	if latency < monitorDegradedThreshold {
		res.Status = MonitorStatusOperational
		return res
	}
	res.Status = MonitorStatusDegraded
	res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs))
	return res
}
// bodyOverrideMode normalizes opts.BodyOverrideMode; both a nil opts and an
// empty string are treated as off.
func bodyOverrideMode(opts *CheckOptions) string {
	if opts != nil && opts.BodyOverrideMode != "" {
		return opts.BodyOverrideMode
	}
	return MonitorBodyOverrideModeOff
}
// pingEndpointOrigin issues a HEAD request to the endpoint's origin
// (scheme://host) and returns the elapsed milliseconds.
// Any failure yields nil (the ping never affects the main status decision).
func pingEndpointOrigin(ctx context.Context, endpoint string) *int {
	origin, err := extractOrigin(endpoint)
	if origin == "" || err != nil {
		return nil
	}
	req, reqErr := http.NewRequestWithContext(ctx, http.MethodHead, origin, nil)
	if reqErr != nil {
		return nil
	}
	begin := time.Now()
	resp, doErr := monitorPingHTTPClient.Do(req)
	if doErr != nil {
		return nil
	}
	defer func() { _ = resp.Body.Close() }()
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, monitorPingDiscardMaxBytes))
	elapsed := int(time.Since(begin) / time.Millisecond)
	return &elapsed
}
// providerAdapter captures the four things a provider needs for the
// challenge check:
//   - building the request path (with the model slot)
//   - serializing the request body
//   - constructing the auth headers
//   - extracting the text from the response JSON by path (gjson path)
//
// Adding a provider only means adding an entry to providerAdapters; neither
// callProvider nor validateProvider needs to change.
type providerAdapter struct {
	buildPath func(model string) string
	buildBody func(model, prompt string) ([]byte, error)
	buildHeaders func(apiKey string) map[string]string
	textPath string // gjson path used to extract the response text
}
// providerAdapters lists every supported provider. Keys are the
// MonitorProvider* string constants.
//
//nolint:gochecknoglobals // read-only static adapter table, never mutated after init.
var providerAdapters = map[string]providerAdapter{
	MonitorProviderOpenAI: {
		buildPath: func(string) string { return providerOpenAIPath },
		buildBody: func(model, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"model": model,
				"messages": []map[string]string{{"role": "user", "content": prompt}},
				"max_tokens": monitorChallengeMaxTokens,
				"stream": false,
			})
		},
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{"Authorization": "Bearer " + apiKey}
		},
		textPath: "choices.0.message.content",
	},
	MonitorProviderAnthropic: {
		buildPath: func(string) string { return providerAnthropicPath },
		buildBody: func(model, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"model": model,
				"messages": []map[string]string{{"role": "user", "content": prompt}},
				"max_tokens": monitorChallengeMaxTokens,
			})
		},
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{
				"x-api-key": apiKey,
				"anthropic-version": monitorAnthropicAPIVersion,
			}
		},
		textPath: "content.0.text",
	},
	MonitorProviderGemini: {
		// Gemini carries the model name in the URL path:
		// /v1beta/models/{model}:generateContent
		buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) },
		buildBody: func(_, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"contents": []map[string]any{
					{"parts": []map[string]any{{"text": prompt}}},
				},
				"generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens},
			})
		},
		// Use the x-goog-api-key header instead of a ?key= query parameter so a
		// *url.Error can never echo the key back into error logs.
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{"x-goog-api-key": apiKey}
		},
		textPath: "candidates.0.content.parts.0.text",
	},
}
// isSupportedProvider reports whether the provider string has an adapter.
// Reused by validate.go's validateProvider so the two switch-like checks can
// never drift apart.
func isSupportedProvider(p string) bool {
	if _, found := providerAdapters[p]; found {
		return true
	}
	return false
}
// callProvider dispatches one check request through providerAdapters.
// opts carries user header/body customization (may be nil).
//
// Returns:
//   - extractedText: text pulled via textPath; meaningful only on 2xx,
//     usually empty otherwise
//   - rawBody: the full response body as a string (already capped at
//     monitorResponseMaxBytes), preserved for the error path
//   - status: the HTTP status code
//   - err: network / serialization errors
func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
	adapter, known := providerAdapters[provider]
	if !known {
		return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
	}
	payload, err := buildRequestBody(adapter, provider, model, prompt, opts)
	if err != nil {
		return "", "", 0, err
	}
	headers := mergeHeaders(adapter.buildHeaders(apiKey), opts)
	requestURL := joinURL(endpoint, adapter.buildPath(model))
	respBytes, status, err := postRawJSON(ctx, requestURL, payload, headers)
	if err != nil {
		return "", "", status, err
	}
	extracted := gjson.GetBytes(respBytes, adapter.textPath).String()
	return extracted, string(respBytes), status, nil
}
// mergeHeaders overlays user-supplied headers onto the adapter defaults.
// User values win; keys on the forbidden list (hop-by-hop headers and ones
// http.Client manages itself) are silently dropped.
func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string {
	if opts == nil || len(opts.ExtraHeaders) == 0 {
		return base
	}
	merged := make(map[string]string, len(base)+len(opts.ExtraHeaders))
	for name, value := range base {
		merged[name] = value
	}
	for name, value := range opts.ExtraHeaders {
		if IsForbiddenHeaderName(name) {
			continue
		}
		merged[name] = value
	}
	return merged
}
// buildRequestBody constructs the request body according to body_override_mode.
//
//   - off: the adapter's default body
//   - merge: shallow-merge BodyOverride into the default body; keys listed in
//     bodyMergeKeyDenyList[provider] are silently dropped so the challenge and
//     model routing stay intact
//   - replace: BodyOverride marshalled verbatim as the whole body
//
// Whatever the mode, the returned []byte is valid JSON ready for postRawJSON.
func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
	mode := bodyOverrideMode(opts)
	if mode == MonitorBodyOverrideModeReplace {
		if opts == nil || len(opts.BodyOverride) == 0 {
			return nil, fmt.Errorf("replace mode: body_override is empty")
		}
		replaced, err := json.Marshal(opts.BodyOverride)
		if err != nil {
			return nil, fmt.Errorf("marshal body_override (replace): %w", err)
		}
		return replaced, nil
	}
	defaultBody, err := adapter.buildBody(model, prompt)
	if err != nil {
		return nil, fmt.Errorf("marshal default body: %w", err)
	}
	noOverride := opts == nil || len(opts.BodyOverride) == 0
	if mode != MonitorBodyOverrideModeMerge || noOverride {
		return defaultBody, nil
	}
	var base map[string]any
	if err := json.Unmarshal(defaultBody, &base); err != nil {
		return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
	}
	denied := bodyMergeKeyDenyList[provider]
	for key, value := range opts.BodyOverride {
		if denied[key] {
			continue
		}
		base[key] = value
	}
	merged, err := json.Marshal(base)
	if err != nil {
		return nil, fmt.Errorf("marshal merged body: %w", err)
	}
	return merged, nil
}
// bodyMergeKeyDenyList blocks users from overriding these provider-specific
// critical fields in merge mode. Same idea as check-cx's
// EXCLUDED_METADATA_KEYS: protect the challenge and model routing from being
// clobbered by accident. Users who really need to change these fields should
// use replace mode (which knowingly skips challenge validation).
//
//nolint:gochecknoglobals // static lookup table, never mutated after init.
var bodyMergeKeyDenyList = map[string]map[string]bool{
	MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true},
	MonitorProviderAnthropic: {"model": true, "messages": true},
	MonitorProviderGemini: {"contents": true},
}
// postRawJSON POSTs pre-serialized JSON bytes, caps how much of the response
// is read, and returns the response bytes, HTTP status, and any error.
// Adapters marshal their own bodies to control field order and types exactly,
// which is why this takes []byte rather than any.
func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
	if err != nil {
		return nil, 0, fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	for name, value := range headers {
		req.Header.Set(name, value)
	}
	resp, err := monitorHTTPClient.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("do request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	limited := io.LimitReader(resp.Body, monitorResponseMaxBytes)
	respBody, err := io.ReadAll(limited)
	if err != nil {
		return nil, resp.StatusCode, fmt.Errorf("read body: %w", err)
	}
	return respBody, resp.StatusCode, nil
}
// joinURL concatenates a base origin and a path into a full URL.
// It tolerates the base ending with or without trailing slashes and
// guarantees the path segment carries exactly one leading slash.
func joinURL(base, path string) string {
	trimmed := strings.TrimRight(base, "/")
	if strings.HasPrefix(path, "/") {
		return trimmed + path
	}
	return trimmed + "/" + path
}
// extractOrigin pulls the scheme://host[:port] part out of an endpoint URL.
func extractOrigin(endpoint string) (string, error) {
	parsed, err := url.Parse(endpoint)
	if err != nil {
		return "", err
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return "", errors.New("endpoint missing scheme or host")
	}
	return parsed.Scheme + "://" + parsed.Host, nil
}
// monitorSensitiveQueryParamRegex matches URL query parameters that may leak
// credentials: key / api_key / api-key / access_token / token /
// authorization / x-api-key. Case-insensitive; matches the `?name=value` or
// `&name=value` form (the value runs up to the next & or end of string).
var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`)
// monitorAPIKeyPatterns matches common providers' API key literals.
// Order matters: sk-ant- must precede sk-, otherwise the generic sk- pattern
// would consume Anthropic keys first.
var monitorAPIKeyPatterns = []struct {
	pattern *regexp.Regexp
	replace string
}{
	// Anthropic (prefixed, must match first): sk-ant-xxxxxxx
	{regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"},
	// Generic OpenAI / Anthropic sk-: sk-xxxxxxx
	{regexp.MustCompile(`sk-[A-Za-z0-9-]{20,}`), "sk-***REDACTED***"},
	// Gemini / Google API key: fixed prefix + 35 characters
	{regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"},
	// Three-part JWTs (often following Bearer): eyJxxx.eyJxxx.signature
	{regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"},
}
// sanitizeErrorMessage scrubs API keys that may leak through error/response
// text. Two sources are handled:
//  1. URL query parameters such as ?key= / ?api_key= (Go's *url.Error echoes
//     the full URL back into the error string)
//  2. key fragments appearing directly in upstream HTTP bodies (sk-*, AIza*,
//     JWTs, ...)
//
// Similar in spirit to sanitizeUpstreamErrorMessage in
// gemini_messages_compat_service.go but with a wider parameter set; the
// monitor module keeps its own copy to avoid coupling.
func sanitizeErrorMessage(msg string) string {
	if msg == "" {
		return msg
	}
	scrubbed := monitorSensitiveQueryParamRegex.ReplaceAllString(msg, `${1}REDACTED`)
	for _, rule := range monitorAPIKeyPatterns {
		scrubbed = rule.pattern.ReplaceAllString(scrubbed, rule.replace)
	}
	return scrubbed
}
// truncateMessage caps msg at monitorMessageMaxBytes to avoid DB column
// overflow and oversized logs.
//
// The cutoff is computed in bytes, so a naive slice could split a multi-byte
// UTF-8 rune and produce an invalid UTF-8 string (which e.g. PostgreSQL text
// columns reject). Back up to a rune boundary before cutting.
func truncateMessage(msg string) string {
	if len(msg) <= monitorMessageMaxBytes {
		return msg
	}
	const ellipsis = "...(truncated)"
	cutoff := monitorMessageMaxBytes - len(ellipsis)
	if cutoff < 0 {
		cutoff = 0
	}
	// 0b10xxxxxx bytes are UTF-8 continuation bytes; step back past them so
	// the cut lands on a rune boundary (same check as utf8.RuneStart).
	for cutoff > 0 && msg[cutoff]&0xC0 == 0x80 {
		cutoff--
	}
	return msg[:cutoff] + ellipsis
}
// truncateForErrorBody squeezes an upstream error body under
// monitorErrorBodySnippetMaxBytes and collapses consecutive whitespace into a
// single space: upstream HTML error pages tend to be full of indentation and
// newlines that would waste the byte budget. truncateMessage still performs
// the final overall cap, so this only trims the body itself.
//
// The cutoff is a byte index, so back up to a UTF-8 rune boundary before
// slicing — cutting mid-rune would yield invalid UTF-8.
func truncateForErrorBody(body string) string {
	body = strings.Join(strings.Fields(body), " ")
	if len(body) <= monitorErrorBodySnippetMaxBytes {
		return body
	}
	const ellipsis = "...(body truncated)"
	cutoff := monitorErrorBodySnippetMaxBytes - len(ellipsis)
	if cutoff < 0 {
		cutoff = 0
	}
	// Skip back over UTF-8 continuation bytes (0b10xxxxxx) so the cut lands
	// on a rune boundary.
	for cutoff > 0 && body[cutoff]&0xC0 == 0x80 {
		cutoff--
	}
	return body[:cutoff] + ellipsis
}
//go:build unit
package service
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// swapMonitorHTTPClient temporarily replaces monitorHTTPClient with a plain
// client that performs no SSRF validation, so httptest servers on 127.0.0.1
// are reachable. The original client is restored when the test finishes.
func swapMonitorHTTPClient(t *testing.T) {
	t.Helper()
	saved := monitorHTTPClient
	t.Cleanup(func() { monitorHTTPClient = saved })
	monitorHTTPClient = &http.Client{Timeout: 5 * time.Second}
}
// captureHandler records each incoming request's body and headers so tests
// can assert on what the checker actually sent.
type captureHandler struct {
	lastBody map[string]any
	lastHeaders http.Header
	respondText string // written into the Anthropic content[0].text field (used for validation)
	status int
}
// ServeHTTP records the request body and headers for later assertions, then
// answers with an Anthropic-shaped JSON response whose content[0].text is
// h.respondText. A zero status defaults to 200.
func (h *captureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.lastHeaders = r.Header.Clone()
	defer func() { _ = r.Body.Close() }()
	var decoded map[string]any
	_ = json.NewDecoder(r.Body).Decode(&decoded)
	h.lastBody = decoded
	if h.status == 0 {
		h.status = 200
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(h.status)
	// Build an Anthropic-format response: content[0].text = h.respondText.
	response := map[string]any{
		"content": []map[string]any{
			{"type": "text", "text": h.respondText},
		},
	}
	_ = json.NewEncoder(w).Encode(response)
}
// setupFakeAnthropic spins up an httptest server backed by handler, swaps in
// a plain HTTP client so 127.0.0.1 is reachable, and returns the server URL.
// Cleanup for both is registered on t.
func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
	t.Helper()
	swapMonitorHTTPClient(t)
	server := httptest.NewServer(handler)
	t.Cleanup(server.Close)
	return server.URL
}
// TestRunCheckForModel_OffMode_PreservesDefaultBody runs one off-mode check
// (opts=nil) and asserts the adapter's default body and headers are sent.
func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
	handler := &captureHandler{respondText: "the answer is 42"}
	endpoint := setupFakeAnthropic(t, handler)
	_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", nil)
	if got := handler.lastBody["model"]; got != "claude-x" {
		t.Errorf("default body should contain model=claude-x, got %v", got)
	}
	if _, present := handler.lastBody["messages"]; !present {
		t.Error("default body should contain messages")
	}
	if key := handler.lastHeaders.Get("x-api-key"); key != "sk-fake" {
		t.Errorf("expected adapter's x-api-key header, got %q", key)
	}
}
// TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects verifies
// merge semantics: user fields overlay the defaults, deny-listed body keys
// and forbidden headers are discarded, and everything else passes through.
func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
	handler := &captureHandler{respondText: "the answer is 42"}
	endpoint := setupFakeAnthropic(t, handler)
	opts := &CheckOptions{
		BodyOverrideMode: MonitorBodyOverrideModeMerge,
		BodyOverride: map[string]any{
			"system":     "You are Claude Code...",
			"max_tokens": float64(999),   // should override the default 50
			"model":      "hacked-model", // deny-listed: the original model must survive
			"messages":   []any{},        // deny-listed as well
		},
		ExtraHeaders: map[string]string{
			"User-Agent":     "claude-cli/1.0",
			"Content-Length": "999", // forbidden header
			"x-custom":       "ok",
		},
	}
	_ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
	if got := handler.lastBody["system"]; got != "You are Claude Code..." {
		t.Errorf("merge mode should inject system, got %v", got)
	}
	// The max_tokens override took effect.
	maxTokens, isFloat := handler.lastBody["max_tokens"].(float64)
	if !isFloat || maxTokens != 999 {
		t.Errorf("merge mode should override max_tokens to 999, got %v", handler.lastBody["max_tokens"])
	}
	// model is deny-listed — the default must remain.
	if got := handler.lastBody["model"]; got != "claude-x" {
		t.Errorf("model should be protected by deny list, got %v", got)
	}
	// messages is deny-listed — the default (non-empty) must remain.
	if messages, _ := handler.lastBody["messages"].([]any); len(messages) == 0 {
		t.Error("messages should be protected by deny list (kept default, non-empty)")
	}
	// Header merging.
	if got := handler.lastHeaders.Get("User-Agent"); got != "claude-cli/1.0" {
		t.Errorf("extra User-Agent should override, got %q", got)
	}
	if got := handler.lastHeaders.Get("x-custom"); got != "ok" {
		t.Errorf("extra custom header should be present, got %q", got)
	}
	// Content-Length is forbidden: net/http recomputes it anyway, so the drop
	// cannot be asserted directly — a successful request is enough.
}
// TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped verifies
// replace mode: the user body goes out untouched, the challenge answer never
// appears in the response, yet 2xx + non-empty text still counts as
// operational.
func TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped(t *testing.T) {
	handler := &captureHandler{respondText: "any non-empty text"}
	endpoint := setupFakeAnthropic(t, handler)
	customBody := map[string]any{
		"model":      "user-forced-model",
		"messages":   []any{map[string]any{"role": "user", "content": "hi"}},
		"max_tokens": float64(10),
		"system":     "You are someone else",
	}
	opts := &CheckOptions{
		BodyOverrideMode: MonitorBodyOverrideModeReplace,
		BodyOverride:     customBody,
	}
	res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
	// The outgoing body is exactly what the user supplied.
	if got := handler.lastBody["model"]; got != "user-forced-model" {
		t.Errorf("replace mode should use user's model, got %v", got)
	}
	if got := handler.lastBody["system"]; got != "You are someone else" {
		t.Errorf("replace mode should use user's system, got %v", got)
	}
	// The challenge never matched, but replace mode skips validation and the
	// non-empty response text makes the check operational.
	if res.Status != MonitorStatusOperational {
		t.Errorf("replace mode with 2xx + non-empty text should be operational, got status=%s message=%q",
			res.Status, res.Message)
	}
}
// TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed verifies that a 2xx
// response with an empty content[0].text degrades replace mode to failed.
func TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed(t *testing.T) {
	handler := &captureHandler{respondText: ""} // upstream 200 but empty content[0].text
	endpoint := setupFakeAnthropic(t, handler)
	opts := &CheckOptions{
		BodyOverrideMode: MonitorBodyOverrideModeReplace,
		BodyOverride:     map[string]any{"model": "x", "messages": []any{}},
	}
	res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
	if res.Status != MonitorStatusFailed {
		t.Errorf("replace mode with empty text should be failed, got status=%s", res.Status)
	}
	if !strings.Contains(res.Message, "replace-mode") {
		t.Errorf("failure message should hint replace-mode, got %q", res.Message)
	}
}
package service
import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// ChannelMonitor global constants.
// These are hard-coded MVP values; promote them to config if needed.
const (
	// monitorRequestTimeout is the total timeout for one model request,
	// including reading the body.
	monitorRequestTimeout = 45 * time.Second
	// monitorPingTimeout bounds the HEAD request against the endpoint origin.
	monitorPingTimeout = 8 * time.Second
	// monitorDegradedThreshold: a successful main request slower than this is
	// reported as degraded.
	monitorDegradedThreshold = 6 * time.Second
	// monitorHistoryRetentionDays is how many days of detailed history to keep.
	// At the 60s default interval, 30 days ≈ 43200 rows per monitor/model;
	// typical deployments stay <= 2M rows total, which PostgreSQL handles
	// easily — so a full month of raw rows is retained and availability
	// queries can run entirely on raw rows without the rollup. The
	// channel_monitor_daily_rollups aggregate table remains as a fallback for
	// long-term history backfill / degraded queries.
	monitorHistoryRetentionDays = 30
	// monitorRollupRetentionDays is how many days of daily rollups to keep.
	// Rollup rows past this window are soft-deleted by RunDailyMaintenance.
	monitorRollupRetentionDays = 30
	// monitorMaintenanceMaxDaysPerRun caps how many days one maintenance run
	// may aggregate: the first-launch backfill (30 days) plus a small margin,
	// keeping transactions short.
	monitorMaintenanceMaxDaysPerRun = 35
	// monitorWorkerConcurrency is the scheduler's concurrent monitor count
	// (pond pool capacity).
	monitorWorkerConcurrency = 5
	// monitorStartupLoadTimeout bounds the one-shot load of all enabled
	// monitors at Start.
	monitorStartupLoadTimeout = 10 * time.Second
	// monitorMinIntervalSeconds / monitorMaxIntervalSeconds bound the
	// user-configurable check interval.
	monitorMinIntervalSeconds = 15
	monitorMaxIntervalSeconds = 3600
	// monitorMessageMaxBytes is the maximum byte size of the message field
	// (matches schema/migration).
	monitorMessageMaxBytes = 500
	// monitorResponseMaxBytes caps how much of a model response is read,
	// guarding against OOM.
	monitorResponseMaxBytes = 64 * 1024
	// monitorErrorBodySnippetMaxBytes caps the upstream body snippet kept for
	// non-2xx responses. 300 bytes covers typical structured errors (such as
	// `{"error":{"message":"..."}}`) while leaving room for the
	// "upstream HTTP <status>: " prefix, so monitorMessageMaxBytes (500)
	// does not truncate the final message too aggressively.
	monitorErrorBodySnippetMaxBytes = 300
	// monitorChallengeMin / monitorChallengeMax bound the challenge operands.
	monitorChallengeMin = 1
	monitorChallengeMax = 50
	// providerOpenAIPath is the OpenAI Chat Completions path.
	providerOpenAIPath = "/v1/chat/completions"
	// providerAnthropicPath is the Anthropic Messages path.
	providerAnthropicPath = "/v1/messages"
	// providerGeminiPathTemplate is the Gemini generateContent path template
	// (with a model placeholder).
	providerGeminiPathTemplate = "/v1beta/models/%s:generateContent"
	// MonitorProviderOpenAI / Anthropic / Gemini are the provider string
	// constants (also the actual ent enum values).
	MonitorProviderOpenAI = "openai"
	MonitorProviderAnthropic = "anthropic"
	MonitorProviderGemini = "gemini"
	// MonitorStatusOperational and friends are the monitor status strings
	// (kept in sync with the ent enum).
	MonitorStatusOperational = "operational"
	MonitorStatusDegraded = "degraded"
	MonitorStatusFailed = "failed"
	MonitorStatusError = "error"
	// monitorAvailability7Days / 15 / 30 are the aggregate query windows.
	monitorAvailability7Days = 7
	monitorAvailability15Days = 15
	monitorAvailability30Days = 30
	// MonitorHistoryDefaultLimit is the default page size for history queries
	// (shared with the handler layer).
	MonitorHistoryDefaultLimit = 100
	// MonitorHistoryMaxLimit is the maximum page size for history queries
	// (shared with the handler layer).
	MonitorHistoryMaxLimit = 1000
	// monitorTimelineMaxPoints caps history points per monitor in the user
	// timeline view.
	monitorTimelineMaxPoints = 60
	// monitorEndpointResolveTimeout bounds hostname resolution inside
	// validateEndpoint.
	monitorEndpointResolveTimeout = 5 * time.Second
	// ---- checker / runner behavior parameters (no magic values) ----
	// monitorAnthropicAPIVersion is the Anthropic Messages API version header.
	monitorAnthropicAPIVersion = "2023-06-01"
	// monitorChallengeMaxTokens is max_tokens for one challenge request —
	// plenty for answering small-number arithmetic.
	monitorChallengeMaxTokens = 50
	// monitorRunOneBuffer is extra headroom on runOne's overall timeout beyond
	// the request and ping timeouts.
	monitorRunOneBuffer = 10 * time.Second
	// monitorIdleConnTimeout closes idle transport connections.
	monitorIdleConnTimeout = 30 * time.Second
	// monitorTLSHandshakeTimeout bounds the transport's TLS handshake.
	monitorTLSHandshakeTimeout = 10 * time.Second
	// monitorResponseHeaderTimeout bounds waiting for response headers.
	monitorResponseHeaderTimeout = 30 * time.Second
	// monitorPingDiscardMaxBytes caps how many body bytes a ping drains.
	monitorPingDiscardMaxBytes = 1024
	// monitorDialTimeout is the per-connection timeout of the custom dialer.
	monitorDialTimeout = 10 * time.Second
	// monitorDialKeepAlive is the custom dialer's keep-alive interval.
	monitorDialKeepAlive = 30 * time.Second
)
// Business errors for channel monitoring, declared together in one place so
// they do not end up scattered across files.
var (
	ErrChannelMonitorNotFound = infraerrors.NotFound(
		"CHANNEL_MONITOR_NOT_FOUND", "channel monitor not found",
	)
	ErrChannelMonitorInvalidProvider = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_PROVIDER", "provider must be one of openai/anthropic/gemini",
	)
	ErrChannelMonitorInvalidInterval = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_INTERVAL", "interval_seconds must be in [15, 3600]",
	)
	ErrChannelMonitorInvalidEndpoint = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_ENDPOINT", "endpoint must be a valid https URL",
	)
	ErrChannelMonitorEndpointScheme = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_SCHEME", "endpoint must use https scheme",
	)
	ErrChannelMonitorEndpointPath = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PATH", "endpoint must be base origin only (no path/query/fragment)",
	)
	ErrChannelMonitorEndpointPrivate = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PRIVATE", "endpoint must be a public host",
	)
	ErrChannelMonitorEndpointUnreachable = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_UNREACHABLE", "endpoint hostname could not be resolved",
	)
	ErrChannelMonitorMissingAPIKey = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_API_KEY", "api_key is required when creating a monitor",
	)
	ErrChannelMonitorMissingPrimaryModel = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_PRIMARY_MODEL", "primary_model is required",
	)
	ErrChannelMonitorAPIKeyDecryptFailed = infraerrors.InternalServer(
		"CHANNEL_MONITOR_KEY_DECRYPT_FAILED", "api key decryption failed; please re-edit the monitor with a fresh key",
	)
)
package service
import (
"context"
"log/slog"
"sync"
"time"
"github.com/alitto/pond/v2"
)
// MonitorScheduler is the scheduler interface that ChannelMonitorService
// calls back into on CRUD operations. It is injected through a setter to
// break the wire dependency cycle between service and runner.
type MonitorScheduler interface {
	// Schedule creates (or resets) the dedicated periodic task for a monitor.
	// When m.Enabled is false it behaves like Unschedule(m.ID).
	Schedule(m *ChannelMonitor)
	// Unschedule cancels the monitor's periodic task, if one exists.
	Unschedule(id int64)
}
// monitorRunnerSvc captures the two service methods the runner actually uses:
//   - loading the enabled monitors at startup
//   - executing a check on every ticker fire
//
// An interface (rather than *ChannelMonitorService) lets runner unit tests
// inject a lightweight stub instead of wiring the full repo + encryptor
// chain. The production *ChannelMonitorService satisfies it naturally.
type monitorRunnerSvc interface {
	ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error)
	RunCheck(ctx context.Context, id int64) ([]*CheckResult, error)
}
// ChannelMonitorRunner schedules channel monitor checks.
//
// Design:
//   - every enabled monitor has its own goroutine + ticker (per-monitor
//     IntervalSeconds)
//   - Start loads all enabled monitors once and creates a task for each
//   - after Create/Update/Delete the service calls back through the
//     MonitorScheduler interface to rebuild/cancel the matching task
//     immediately (no DB polling)
//   - the actual HTTP checks run on a pond pool (capacity
//     monitorWorkerConcurrency) so bursts cannot overwhelm upstreams
//
// History cleanup and daily rollup maintenance are triggered by
// OpsCleanupService's cron via ChannelMonitorService.RunDailyMaintenance
// (reusing the leader lock + heartbeat) and are not the runner's job.
type ChannelMonitorRunner struct {
	svc monitorRunnerSvc
	settingService *SettingService
	pool pond.Pool
	parentCtx context.Context
	parentCancel context.CancelFunc
	mu sync.Mutex
	tasks map[int64]*scheduledMonitor
	wg sync.WaitGroup
	started bool
	stopped bool
	// inFlight tracks monitor IDs currently executing. fire checks it before
	// submitting, so a monitor whose single check outlasts its interval is
	// never executed concurrently with itself.
	inFlight map[int64]struct{}
	inFlightMu sync.Mutex
}
// scheduledMonitor is the runtime context of a single scheduled monitor.
type scheduledMonitor struct {
	id int64
	name string
	interval time.Duration
	cancel context.CancelFunc
}
// NewChannelMonitorRunner constructs the scheduler; Start is invoked exactly
// once from wire. settingService is consulted before every fire for the
// feature flag; passing nil means "always enabled" (test compatibility).
//
// The pool is created at construction time: this avoids the race hazard of
// Start assigning it under mu while fire/Stop read it outside mu, and
// pond.NewPool itself is essentially free, so building the pool early wastes
// no resources.
func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
	return newChannelMonitorRunner(svc, settingService)
}
// newChannelMonitorRunner is the internal constructor. It accepts the
// minimal interface so unit tests can inject a stub service.
func newChannelMonitorRunner(svc monitorRunnerSvc, settingService *SettingService) *ChannelMonitorRunner {
	parentCtx, parentCancel := context.WithCancel(context.Background())
	runner := &ChannelMonitorRunner{
		svc:            svc,
		settingService: settingService,
		pool:           pond.NewPool(monitorWorkerConcurrency),
		parentCtx:      parentCtx,
		parentCancel:   parentCancel,
		tasks:          map[int64]*scheduledMonitor{},
		inFlight:       map[int64]struct{}{},
	}
	return runner
}
// Start loads every enabled monitor and creates one dedicated periodic task
// per monitor. Callers must invoke it only once (the wire provider
// ProvideChannelMonitorRunner does exactly that).
func (r *ChannelMonitorRunner) Start() {
	if r == nil || r.svc == nil {
		return
	}
	r.mu.Lock()
	blocked := r.started || r.stopped
	if !blocked {
		r.started = true
	}
	r.mu.Unlock()
	if blocked {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), monitorStartupLoadTimeout)
	defer cancel()
	monitors, err := r.svc.ListEnabledMonitors(ctx)
	if err != nil {
		slog.Error("channel_monitor: load enabled monitors failed at startup", "error", err)
		return
	}
	for _, monitor := range monitors {
		r.Schedule(monitor)
	}
	slog.Info("channel_monitor: runner started", "scheduled_tasks", len(monitors))
}
// Schedule creates (or resets) the dedicated periodic task for a monitor.
//   - m.Enabled=false → equivalent to Unschedule(m.ID)
//   - an existing task is cancelled and rebuilt (covers IntervalSeconds changes)
//   - a new task fires immediately, then every IntervalSeconds
func (r *ChannelMonitorRunner) Schedule(m *ChannelMonitor) {
	if r == nil || m == nil {
		return
	}
	if !m.Enabled {
		r.Unschedule(m.ID)
		return
	}
	interval := time.Duration(m.IntervalSeconds) * time.Second
	if interval <= 0 {
		// Create/Update validate the range through validateInterval, so this
		// branch is unreachable on normal paths. If it does trigger, the database
		// holds constraint-violating data or the validation chain is broken —
		// log at Error level to surface the problem.
		slog.Error("channel_monitor: skip schedule for invalid interval",
			"monitor_id", m.ID, "interval_seconds", m.IntervalSeconds)
		return
	}
	r.mu.Lock()
	if r.stopped {
		r.mu.Unlock()
		return
	}
	if !r.started {
		// Schedule before Start usually means the wire order is wrong: the wire
		// sequence is SetScheduler → Start, and CRUD hooks can fire no earlier
		// than the first request — by which point Start has long finished. Log
		// the monitor details for diagnosis; do not queue or cache — operators
		// resolve this by restarting or repairing the wire setup.
		r.mu.Unlock()
		slog.Warn("channel_monitor: schedule before runner started, skip",
			"monitor_id", m.ID, "name", m.Name)
		return
	}
	if existing, ok := r.tasks[m.ID]; ok {
		existing.cancel()
	}
	ctx, cancel := context.WithCancel(r.parentCtx)
	task := &scheduledMonitor{
		id:       m.ID,
		name:     m.Name,
		interval: interval,
		cancel:   cancel,
	}
	r.tasks[m.ID] = task
	r.wg.Add(1)
	r.mu.Unlock()
	go r.runScheduled(ctx, task)
}
// Unschedule cancels the monitor's periodic task, if one exists.
// A check already in flight receives the cancellation via its context.
func (r *ChannelMonitorRunner) Unschedule(id int64) {
	if r == nil {
		return
	}
	r.mu.Lock()
	task, found := r.tasks[id]
	if found {
		delete(r.tasks, id)
	}
	r.mu.Unlock()
	if !found {
		return
	}
	task.cancel()
}
// Stop shuts the runner down gracefully: cancel every task, wait for the
// loops to exit, then drain and close the worker pool.
func (r *ChannelMonitorRunner) Stop() {
	if r == nil {
		return
	}
	r.mu.Lock()
	alreadyStopped := r.stopped
	if !alreadyStopped {
		r.stopped = true
		r.parentCancel()
		r.tasks = nil
	}
	r.mu.Unlock()
	if alreadyStopped {
		return
	}
	r.wg.Wait()
	r.pool.StopAndWait()
}
// runScheduled is one monitor's loop: fire immediately (so a newly created or
// re-enabled monitor is checked right away), then on every tick until the
// context is cancelled.
func (r *ChannelMonitorRunner) runScheduled(ctx context.Context, task *scheduledMonitor) {
	defer r.wg.Done()
	r.fire(ctx, task)
	ticker := time.NewTicker(task.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			r.fire(ctx, task)
		case <-ctx.Done():
			return
		}
	}
}
// fire submits one check to the worker pool. When the feature flag is off
// the fire is skipped (the task stays scheduled and resumes as soon as the
// flag is re-enabled); a full pool or an already-in-flight check also skips.
func (r *ChannelMonitorRunner) fire(ctx context.Context, task *scheduledMonitor) {
	if r.settingService != nil && !r.settingService.GetChannelMonitorRuntime(ctx).Enabled {
		return
	}
	if !r.tryAcquireInFlight(task.id) {
		slog.Debug("channel_monitor: skip already in-flight",
			"monitor_id", task.id, "name", task.name)
		return
	}
	_, submitted := r.pool.TrySubmit(func() {
		r.runOne(task.id, task.name)
	})
	if submitted {
		return
	}
	// Pool full: drop this check, but release the in-flight slot we just
	// claimed — otherwise this monitor would be stuck forever.
	r.releaseInFlight(task.id)
	slog.Warn("channel_monitor: worker pool full, skip submission",
		"monitor_id", task.id, "name", task.name)
}
// tryAcquireInFlight atomically claims the monitor's in-flight slot.
// It returns false when the slot is already taken (the caller should skip
// this submission).
func (r *ChannelMonitorRunner) tryAcquireInFlight(id int64) bool {
	r.inFlightMu.Lock()
	defer r.inFlightMu.Unlock()
	_, busy := r.inFlight[id]
	if busy {
		return false
	}
	r.inFlight[id] = struct{}{}
	return true
}
// releaseInFlight frees the in-flight slot. It must be called once runOne
// finishes, including on the panic-recover path.
func (r *ChannelMonitorRunner) releaseInFlight(id int64) {
	r.inFlightMu.Lock()
	defer r.inFlightMu.Unlock()
	delete(r.inFlight, id)
}
// runOne executes one monitor's check. Errors are only logged — there is no
// circuit breaking. The in-flight slot is always released on exit, even when
// the check panics.
func (r *ChannelMonitorRunner) runOne(id int64, name string) {
	overallTimeout := monitorRequestTimeout + monitorPingTimeout + monitorRunOneBuffer
	ctx, cancel := context.WithTimeout(context.Background(), overallTimeout)
	defer cancel()
	defer r.releaseInFlight(id)
	defer func() {
		if rec := recover(); rec != nil {
			slog.Error("channel_monitor: runner panic",
				"monitor_id", id, "name", name, "panic", rec)
		}
	}()
	_, err := r.svc.RunCheck(ctx, id)
	if err != nil {
		slog.Warn("channel_monitor: run check failed",
			"monitor_id", id, "name", name, "error", err)
	}
}
//go:build unit
package service
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
)
// stubMonitorSvc implements monitorRunnerSvc, isolating the runner from the
// real service/repo stack in unit tests.
type stubMonitorSvc struct {
	enabled []*ChannelMonitor
	runCount atomic.Int64
	runCalled chan int64 // receives the id on each RunCheck (buffer sized large enough to never block)
	runErr error
	listErr error
	runHoldFor time.Duration // extra time RunCheck blocks, used to exercise Stop's wait behavior
}
// ListEnabledMonitors returns the canned monitor list, or listErr when set.
func (s *stubMonitorSvc) ListEnabledMonitors(_ context.Context) ([]*ChannelMonitor, error) {
	if s.listErr == nil {
		return s.enabled, nil
	}
	return nil, s.listErr
}
// RunCheck records the invocation, optionally signals runCalled (without ever
// blocking the caller), optionally holds for runHoldFor or until ctx is
// cancelled, then returns runErr.
func (s *stubMonitorSvc) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
	s.runCount.Add(1)
	if ch := s.runCalled; ch != nil {
		select {
		case ch <- id:
		default: // buffer full: drop the signal rather than block
		}
	}
	if hold := s.runHoldFor; hold > 0 {
		select {
		case <-time.After(hold):
		case <-ctx.Done():
		}
	}
	return nil, s.runErr
}
// newRunnerForTest builds a runner backed by the given (stub) service and no
// settings source (nil second argument).
func newRunnerForTest(svc monitorRunnerSvc) *ChannelMonitorRunner {
	return newChannelMonitorRunner(svc, nil)
}
// 等待 condition 在 timeout 内变 true,否则 t.Fatalf。轮询 5ms 一次。
func waitFor(t *testing.T, timeout time.Duration, msg string, cond func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(5 * time.Millisecond)
}
if !cond() {
t.Fatalf("waitFor timed out: %s", msg)
}
}
// runnerTaskCount reports how many tasks are currently registered.
func runnerTaskCount(r *ChannelMonitorRunner) int {
	r.mu.Lock()
	n := len(r.tasks)
	r.mu.Unlock()
	return n
}
// runnerTaskPtr returns the scheduled task registered under id (nil if absent).
func runnerTaskPtr(r *ChannelMonitorRunner, id int64) *scheduledMonitor {
	r.mu.Lock()
	task := r.tasks[id]
	r.mu.Unlock()
	return task
}
// TestSchedule_AddsTaskAndFiresOnce verifies Schedule registers the task in
// the tasks table and triggers an immediate first check.
func TestSchedule_AddsTaskAndFiresOnce(t *testing.T) {
	stub := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	runner := newRunnerForTest(stub)
	runner.Start() // stub.enabled is empty, so Start returns immediately
	runner.Schedule(&ChannelMonitor{ID: 1, Name: "m1", Enabled: true, IntervalSeconds: 60})
	if n := runnerTaskCount(runner); n != 1 {
		t.Fatalf("expected 1 scheduled task, got %d", n)
	}
	select {
	case id := <-stub.runCalled:
		if id != 1 {
			t.Fatalf("expected first fire for id=1, got %d", id)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("expected immediate first fire within 2s")
	}
	runner.Stop()
}
// TestSchedule_ReplaceCancelsOldTask verifies that a second Schedule for the
// same id swaps in a fresh task instance. (The old goroutine exits via ctx
// cancellation; evidence here: distinct task pointers + Stop not timing out.)
func TestSchedule_ReplaceCancelsOldTask(t *testing.T) {
	stub := &stubMonitorSvc{runCalled: make(chan int64, 8)}
	runner := newRunnerForTest(stub)
	runner.Start()
	mon := &ChannelMonitor{ID: 7, Name: "m7", Enabled: true, IntervalSeconds: 60}
	runner.Schedule(mon)
	before := runnerTaskPtr(runner, 7)
	if before == nil {
		t.Fatal("first schedule did not register task")
	}
	runner.Schedule(mon)
	after := runnerTaskPtr(runner, 7)
	if after == nil {
		t.Fatal("second schedule did not register task")
	}
	if before == after {
		t.Fatal("re-Schedule should create a new scheduledMonitor instance")
	}
	stoppedWithin(t, runner, 3*time.Second)
}
// TestUnschedule_RemovesTask verifies Unschedule drops the task entry and the
// matching goroutine exits (Stop returns promptly).
func TestUnschedule_RemovesTask(t *testing.T) {
	stub := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	runner := newRunnerForTest(stub)
	runner.Start()
	runner.Schedule(&ChannelMonitor{ID: 3, Enabled: true, IntervalSeconds: 60})
	waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(runner) == 1 })
	runner.Unschedule(3)
	if n := runnerTaskCount(runner); n != 0 {
		t.Fatalf("expected tasks empty after Unschedule, got %d", n)
	}
	stoppedWithin(t, runner, 3*time.Second)
}
// TestSchedule_DisabledRedirectsToUnschedule verifies Enabled=false acts as
// an Unschedule for an already-registered monitor.
func TestSchedule_DisabledRedirectsToUnschedule(t *testing.T) {
	stub := &stubMonitorSvc{runCalled: make(chan int64, 4)}
	runner := newRunnerForTest(stub)
	runner.Start()
	runner.Schedule(&ChannelMonitor{ID: 9, Enabled: true, IntervalSeconds: 60})
	waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(runner) == 1 })
	runner.Schedule(&ChannelMonitor{ID: 9, Enabled: false, IntervalSeconds: 60})
	if n := runnerTaskCount(runner); n != 0 {
		t.Fatalf("expected tasks empty after disabled re-Schedule, got %d", n)
	}
	stoppedWithin(t, runner, 3*time.Second)
}
// TestSchedule_InvalidIntervalSkipped 验证 IntervalSeconds<=0 不会注册任务(防御性检查)。
func TestSchedule_InvalidIntervalSkipped(t *testing.T) {
svc := &stubMonitorSvc{}
r := newRunnerForTest(svc)
r.Start()
r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 0})
if got := runnerTaskCount(r); got != 0 {
t.Fatalf("expected no task for invalid interval, got %d", got)
}
r.Stop()
}
// TestSchedule_BeforeStartIsNoOp 验证 Start 之前调用 Schedule 不会注册任务。
func TestSchedule_BeforeStartIsNoOp(t *testing.T) {
svc := &stubMonitorSvc{}
r := newRunnerForTest(svc)
// 故意不调用 Start
r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
if got := runnerTaskCount(r); got != 0 {
t.Fatalf("expected no task before Start, got %d", got)
}
r.Stop()
}
// TestStart_LoadsAllEnabledMonitors verifies Start schedules one task per
// record returned by ListEnabledMonitors.
func TestStart_LoadsAllEnabledMonitors(t *testing.T) {
	stub := &stubMonitorSvc{
		enabled: []*ChannelMonitor{
			{ID: 1, Enabled: true, IntervalSeconds: 60},
			{ID: 2, Enabled: true, IntervalSeconds: 60},
			{ID: 3, Enabled: true, IntervalSeconds: 60},
		},
	}
	runner := newRunnerForTest(stub)
	runner.Start()
	waitFor(t, 2*time.Second, "all 3 tasks scheduled", func() bool { return runnerTaskCount(runner) == 3 })
	stoppedWithin(t, runner, 3*time.Second)
}
// TestStop_DrainsAllGoroutines 验证 Stop 会等待所有调度 goroutine 退出(无游离)。
func TestStop_DrainsAllGoroutines(t *testing.T) {
svc := &stubMonitorSvc{}
r := newRunnerForTest(svc)
r.Start()
for id := int64(1); id <= 5; id++ {
r.Schedule(&ChannelMonitor{ID: id, Enabled: true, IntervalSeconds: 60})
}
waitFor(t, 2*time.Second, "5 tasks scheduled", func() bool { return runnerTaskCount(r) == 5 })
stoppedWithin(t, r, 3*time.Second)
}
// TestStop_WaitsForInFlightCheck verifies Stop blocks until an in-progress
// RunCheck finishes (pool.StopAndWait semantics).
func TestStop_WaitsForInFlightCheck(t *testing.T) {
	stub := &stubMonitorSvc{
		runCalled:  make(chan int64, 1),
		runHoldFor: 200 * time.Millisecond,
	}
	runner := newRunnerForTest(stub)
	runner.Start()
	runner.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
	select {
	case <-stub.runCalled:
	case <-time.After(2 * time.Second):
		t.Fatal("first fire never happened")
	}
	begin := time.Now()
	stoppedWithin(t, runner, 3*time.Second)
	// Stop must wait out the in-flight check (runHoldFor=200ms); anything
	// under ~100ms means it returned without waiting.
	if waited := time.Since(begin); waited < 100*time.Millisecond {
		t.Fatalf("Stop returned too fast (%v); did not wait for in-flight check", waited)
	}
}
// TestInFlight_AcquireReleaseSymmetric exercises the in-flight slot directly:
// acquire must succeed once, fail while held, and succeed again after release.
//
// NOTE(review): this was originally planned as a "pool full releases slot"
// test driven through fire, but stubbing pond.Pool's TrySubmit is awkward, so
// the invariant is verified at the acquire/release level instead. The comment
// previously still carried the old test name TestInFlight_PoolFullReleasesSlot.
func TestInFlight_AcquireReleaseSymmetric(t *testing.T) {
	svc := &stubMonitorSvc{}
	r := newRunnerForTest(svc)
	if !r.tryAcquireInFlight(42) {
		t.Fatal("first acquire should succeed")
	}
	if r.tryAcquireInFlight(42) {
		t.Fatal("second acquire (no release) must fail")
	}
	r.releaseInFlight(42)
	if !r.tryAcquireInFlight(42) {
		t.Fatal("acquire after release should succeed")
	}
	r.releaseInFlight(42)
}
// stoppedWithin invokes Stop on a separate goroutine and fails the test when
// it does not return within timeout — proof that Stop never blocks forever.
func stoppedWithin(t *testing.T, r *ChannelMonitorRunner, timeout time.Duration) {
	t.Helper()
	stopped := make(chan struct{})
	var once sync.Once
	go func() {
		r.Stop()
		once.Do(func() { close(stopped) })
	}()
	select {
	case <-stopped:
	case <-time.After(timeout):
		t.Fatalf("Stop did not return within %s — leaked goroutine?", timeout)
	}
}
package service
import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"golang.org/x/sync/errgroup"
)
// ChannelMonitorRepository is the data-access interface for channel monitors.
// Parameters and return values use the service package's ChannelMonitor model;
// the repository implementation converts to/from ent models and keeps the
// api_key_encrypted field as ciphertext.
type ChannelMonitorRepository interface {
	// CRUD
	Create(ctx context.Context, m *ChannelMonitor) error
	GetByID(ctx context.Context, id int64) (*ChannelMonitor, error)
	Update(ctx context.Context, m *ChannelMonitor) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error)
	// Scheduler helpers
	ListEnabled(ctx context.Context) ([]*ChannelMonitor, error)
	MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error
	InsertHistoryBatch(ctx context.Context, rows []*ChannelMonitorHistoryRow) error
	DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error)
	// History
	ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error)
	// User-view aggregation
	ListLatestPerModel(ctx context.Context, monitorID int64) ([]*ChannelMonitorLatest, error)
	ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*ChannelMonitorAvailability, error)
	// Batch aggregation (used by admin/user list endpoints to avoid N+1 queries)
	ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error)
	ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error)
	// ListRecentHistoryForMonitors fetches, for each monitor, the latest
	// perMonitorLimit history entries of its primary model
	// (primaryModels[monitorID]). Entries come back ordered by checked_at DESC
	// (newest first) and omit the message field.
	ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)
	// ---------- Aggregation maintenance (called by OpsCleanupService) ----------
	// UpsertDailyRollupsFor aggregates targetDate's detail rows by
	// (monitor_id, model, bucket_date) into channel_monitor_daily_rollups.
	// targetDate is truncated to the date; ON CONFLICT DO UPDATE makes
	// backfills idempotent. Returns the number of rows upserted.
	UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error)
	// DeleteRollupsBefore soft-deletes rollup rows with bucket_date < beforeDate
	// and returns the number of rows removed.
	DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error)
	// LoadAggregationWatermark reads the watermark row (id=1). A nil result
	// means nothing has been aggregated yet; the watermark row itself is
	// expected to exist (written by migration 110).
	LoadAggregationWatermark(ctx context.Context) (*time.Time, error)
	// UpdateAggregationWatermark writes the watermark (UPSERT onto id=1).
	UpdateAggregationWatermark(ctx context.Context, date time.Time) error
}
// ChannelMonitorService manages channel monitors.
type ChannelMonitorService struct {
	repo ChannelMonitorRepository
	encryptor SecretEncryptor
	// scheduler is injected by wire via SetScheduler; CRUD methods call the
	// matching hooks to sync scheduled tasks immediately. In tests, or when
	// never injected, it stays nil and every hook call becomes a no-op.
	scheduler MonitorScheduler
}
// NewChannelMonitorService constructs a channel-monitor management service.
// The scheduler is attached later via SetScheduler.
func NewChannelMonitorService(repo ChannelMonitorRepository, encryptor SecretEncryptor) *ChannelMonitorService {
	svc := &ChannelMonitorService{repo: repo, encryptor: encryptor}
	return svc
}
// ---------- CRUD ----------
// List returns monitors matching params (provider/enabled/search filters plus
// pagination). APIKey on each returned item is decrypted to plaintext; the
// handler layer is responsible for masking.
func (s *ChannelMonitorService) List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error) {
	if params.Page < 1 {
		params.Page = 1
	}
	if params.PageSize < 1 || params.PageSize > 200 {
		params.PageSize = 20
	}
	items, total, err := s.repo.List(ctx, params)
	if err != nil {
		return nil, 0, fmt.Errorf("list channel monitors: %w", err)
	}
	for _, item := range items {
		s.decryptInPlace(item)
	}
	return items, total, nil
}
// Get fetches a single monitor and decrypts its API key in place.
func (s *ChannelMonitorService) Get(ctx context.Context, id int64) (*ChannelMonitor, error) {
	monitor, err := s.repo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	s.decryptInPlace(monitor)
	return monitor, nil
}
// Create validates the input, encrypts the API key, persists the monitor and
// registers it with the scheduler (when one is injected).
//
// Fix: the API key is trimmed once up front and that trimmed value is both
// encrypted and echoed back. Previously the raw (untrimmed) p.APIKey was
// encrypted while the trimmed value was returned, so a key submitted with
// surrounding whitespace would have that whitespace baked into the stored
// ciphertext — inconsistent with applyAPIKeyUpdate, which encrypts the
// trimmed key on the Update path.
func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCreateParams) (*ChannelMonitor, error) {
	if err := validateCreateParams(p); err != nil {
		return nil, err
	}
	if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
		return nil, err
	}
	if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
		return nil, err
	}
	plainKey := strings.TrimSpace(p.APIKey)
	encrypted, err := s.encryptor.Encrypt(plainKey)
	if err != nil {
		return nil, fmt.Errorf("encrypt api key: %w", err)
	}
	m := &ChannelMonitor{
		Name:             strings.TrimSpace(p.Name),
		Provider:         p.Provider,
		Endpoint:         normalizeEndpoint(p.Endpoint),
		APIKey:           encrypted, // ciphertext while handed to the repository
		PrimaryModel:     strings.TrimSpace(p.PrimaryModel),
		ExtraModels:      normalizeModels(p.ExtraModels),
		GroupName:        strings.TrimSpace(p.GroupName),
		Enabled:          p.Enabled,
		IntervalSeconds:  p.IntervalSeconds,
		CreatedBy:        p.CreatedBy,
		TemplateID:       p.TemplateID,
		ExtraHeaders:     emptyHeadersIfNil(p.ExtraHeaders),
		BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
		BodyOverride:     p.BodyOverride,
	}
	if err := s.repo.Create(ctx, m); err != nil {
		return nil, fmt.Errorf("create channel monitor: %w", err)
	}
	// Do not re-read via s.Get: we already know the plaintext we just
	// encrypted, so build the response directly. This avoids the decrypt path
	// silently blanking APIKey when SecretEncryptor fails (see Fix 4).
	m.APIKey = plainKey
	if s.scheduler != nil {
		s.scheduler.Schedule(m)
	}
	return m, nil
}
// validateCreateParams gathers every Create input check in one place so the
// Create body stays under the 30-line guideline.
func validateCreateParams(p ChannelMonitorCreateParams) error {
	if err := validateProvider(p.Provider); err != nil {
		return err
	}
	if err := validateInterval(p.IntervalSeconds); err != nil {
		return err
	}
	if err := validateEndpoint(p.Endpoint); err != nil {
		return err
	}
	switch {
	case strings.TrimSpace(p.APIKey) == "":
		return ErrChannelMonitorMissingAPIKey
	case strings.TrimSpace(p.PrimaryModel) == "":
		return ErrChannelMonitorMissingPrimaryModel
	}
	return nil
}
// Update applies a partial update to a monitor. APIKey field semantics:
// nil or empty string = leave unchanged; non-empty = encrypt and overwrite.
func (s *ChannelMonitorService) Update(ctx context.Context, id int64, p ChannelMonitorUpdateParams) (*ChannelMonitor, error) {
	existing, err := s.repo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if err := applyMonitorUpdate(existing, p); err != nil {
		return nil, err
	}
	newPlainAPIKey, apiKeyUpdated, err := s.applyAPIKeyUpdate(existing, p.APIKey)
	if err != nil {
		return nil, err
	}
	if err := s.repo.Update(ctx, existing); err != nil {
		return nil, fmt.Errorf("update channel monitor: %w", err)
	}
	// Skip the s.Get decrypt round-trip: a second decrypt risks silently
	// blanking the ciphertext on failure (same rationale as Create).
	if apiKeyUpdated {
		existing.APIKey = newPlainAPIKey
	} else {
		s.decryptInPlace(existing)
	}
	if s.scheduler != nil {
		// Schedule internally picks Unschedule vs rebuild based on Enabled;
		// an IntervalSeconds change is absorbed naturally (old task cancelled,
		// new task created with the new interval).
		s.scheduler.Schedule(existing)
	}
	return existing, nil
}
// applyAPIKeyUpdate handles the APIKey field during Update:
//   - raw nil or blank: existing.APIKey stays untouched (still ciphertext),
//     updated=false
//   - otherwise: the trimmed key is encrypted into existing.APIKey, and the
//     plaintext is returned so the caller can place it back on existing after
//     a successful write (never echo ciphertext to the client)
func (s *ChannelMonitorService) applyAPIKeyUpdate(existing *ChannelMonitor, raw *string) (plain string, updated bool, err error) {
	if raw == nil {
		return "", false, nil
	}
	plain = strings.TrimSpace(*raw)
	if plain == "" {
		return "", false, nil
	}
	encrypted, encErr := s.encryptor.Encrypt(plain)
	if encErr != nil {
		return "", false, fmt.Errorf("encrypt api key: %w", encErr)
	}
	existing.APIKey = encrypted
	return plain, true, nil
}
// Delete removes a monitor (its history rows are cleaned up via FK CASCADE)
// and unschedules it when a scheduler is attached.
func (s *ChannelMonitorService) Delete(ctx context.Context, id int64) error {
	if err := s.repo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete channel monitor: %w", err)
	}
	if s.scheduler == nil {
		return nil
	}
	s.scheduler.Unschedule(id)
	return nil
}
// ListHistory returns the most recent check history for one monitor.
// An empty model matches all models; limit <= 0 uses the default, and values
// over the cap are clamped. The monitor must exist (GetByID is checked first).
func (s *ChannelMonitorService) ListHistory(ctx context.Context, id int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error) {
	if _, err := s.repo.GetByID(ctx, id); err != nil {
		return nil, err
	}
	if limit <= 0 {
		limit = MonitorHistoryDefaultLimit
	}
	if limit > MonitorHistoryMaxLimit {
		limit = MonitorHistoryMaxLimit
	}
	rows, err := s.repo.ListHistory(ctx, id, strings.TrimSpace(model), limit)
	if err != nil {
		return nil, fmt.Errorf("list history: %w", err)
	}
	return rows, nil
}
// ---------- Business ----------
// RunCheck synchronously checks one monitor: it runs the primary + extra
// models concurrently, writes history rows and updates last_checked_at.
// Returns one result per model.
func (s *ChannelMonitorService) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
	m, err := s.Get(ctx, id) // APIKey already decrypted by Get
	if err != nil {
		return nil, err
	}
	if m.APIKeyDecryptFailed {
		return nil, ErrChannelMonitorAPIKeyDecryptFailed
	}
	results := s.runChecksConcurrent(ctx, m)
	s.persistCheckResults(ctx, m, results)
	return results, nil
}
// persistCheckResults stores the history rows for this check run and bumps
// last_checked_at. Storage failures are logged only so the caller still gets
// the results back (MVP trade-off: losing a history row beats losing the run).
func (s *ChannelMonitorService) persistCheckResults(ctx context.Context, m *ChannelMonitor, results []*CheckResult) {
	rows := make([]*ChannelMonitorHistoryRow, len(results))
	for i, res := range results {
		rows[i] = &ChannelMonitorHistoryRow{
			MonitorID:     m.ID,
			Model:         res.Model,
			Status:        res.Status,
			LatencyMs:     res.LatencyMs,
			PingLatencyMs: res.PingLatencyMs,
			Message:       res.Message,
			CheckedAt:     res.CheckedAt,
		}
	}
	if err := s.repo.InsertHistoryBatch(ctx, rows); err != nil {
		slog.Error("channel_monitor: insert history failed",
			"monitor_id", m.ID, "name", m.Name, "error", err)
	}
	if err := s.repo.MarkChecked(ctx, m.ID, time.Now()); err != nil {
		slog.Error("channel_monitor: mark checked failed",
			"monitor_id", m.ID, "error", err)
	}
}
// runChecksConcurrent checks the primary + extra models concurrently.
// errgroup is used only as a join point, not for error propagation: each
// model's failure is already packed into its CheckResult.
func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *ChannelMonitor) []*CheckResult {
	models := append([]string{m.PrimaryModel}, m.ExtraModels...)
	results := make([]*CheckResult, len(models))
	// Ping once and share it: every model records the same ping latency.
	pingMs := pingEndpointOrigin(ctx, m.Endpoint)
	// All models share one CheckOptions built from the monitor's snapshot fields.
	opts := &CheckOptions{
		ExtraHeaders:     m.ExtraHeaders,
		BodyOverrideMode: m.BodyOverrideMode,
		BodyOverride:     m.BodyOverride,
	}
	var eg errgroup.Group
	var mu sync.Mutex
	for i, model := range models {
		i, model := i, model // per-iteration copies for the closure (pre-Go 1.22 loop semantics)
		eg.Go(func() error {
			r := runCheckForModel(ctx, m.Provider, m.Endpoint, m.APIKey, model, opts)
			r.PingLatencyMs = pingMs
			mu.Lock()
			results[i] = r
			mu.Unlock()
			return nil
		})
	}
	_ = eg.Wait() // workers always return nil; Wait is purely for synchronization
	return results
}
// ---------- Scheduler collaboration ----------
// SetScheduler is injected by wire after the runner is constructed, letting
// CRUD operations sync the task table immediately. Setter injection avoids a
// service ↔ runner dependency cycle.
func (s *ChannelMonitorService) SetScheduler(sched MonitorScheduler) {
	s.scheduler = sched
}
// ListEnabledMonitors returns every enabled monitor with its API key
// decrypted, used by the runner to build its task table at startup.
func (s *ChannelMonitorService) ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error) {
	monitors, err := s.repo.ListEnabled(ctx)
	if err != nil {
		return nil, err
	}
	for _, monitor := range monitors {
		s.decryptInPlace(monitor)
	}
	return monitors, nil
}
// cleanupOldHistory soft-deletes detail history older than
// monitorHistoryRetentionDays (SoftDeleteMixin rewrites DELETE into an
// UPDATE of deleted_at). Invoked from RunDailyMaintenance.
func (s *ChannelMonitorService) cleanupOldHistory(ctx context.Context) error {
	cutoff := time.Now().UTC().AddDate(0, 0, -monitorHistoryRetentionDays)
	deleted, err := s.repo.DeleteHistoryBefore(ctx, cutoff)
	if err != nil {
		return fmt.Errorf("delete history before %s: %w", cutoff.Format(time.RFC3339), err)
	}
	if deleted > 0 {
		slog.Info("channel_monitor: history cleanup",
			"deleted_rows", deleted, "before", cutoff.Format(time.RFC3339))
	}
	return nil
}
// RunDailyMaintenance is the daily job: aggregate detail rows up to yesterday
// that are not yet rolled up, then soft-delete expired detail and rollup rows.
// Triggered from OpsCleanupService's cron (sharing its schedule and leader lock).
//
// Idempotency:
//   - the watermark guarantees already-aggregated dates are never reprocessed;
//   - UpsertDailyRollupsFor uses ON CONFLICT DO UPDATE, so re-running the same
//     day yields identical results.
//
// Each step only logs slog.Warn on failure; the function always returns nil so
// later steps keep running (matching OpsCleanupService.runCleanupOnce style).
func (s *ChannelMonitorService) RunDailyMaintenance(ctx context.Context) error {
	now := time.Now().UTC()
	today := now.Truncate(24 * time.Hour)
	if err := s.runDailyAggregation(ctx, today); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "aggregate", "error", err)
	}
	if err := s.cleanupOldHistory(ctx); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "prune_history", "error", err)
	}
	if err := s.cleanupOldRollups(ctx, today); err != nil {
		slog.Warn("channel_monitor: maintenance step failed",
			"step", "prune_rollups", "error", err)
	}
	return nil
}
// runDailyAggregation aggregates from watermark+1 through yesterday (UTC).
// First run (nil watermark): backfills from today-monitorRollupRetentionDays.
// At most monitorMaintenanceMaxDaysPerRun days are processed per invocation
// to avoid long transactions; the watermark is advanced after each day so an
// interrupted run resumes exactly where it stopped.
func (s *ChannelMonitorService) runDailyAggregation(ctx context.Context, today time.Time) error {
	watermark, err := s.repo.LoadAggregationWatermark(ctx)
	if err != nil {
		return fmt.Errorf("load watermark: %w", err)
	}
	start := s.resolveAggregationStart(watermark, today)
	if !start.Before(today) {
		return nil // nothing left to aggregate
	}
	iterations := 0
	for d := start; d.Before(today); d = d.Add(24 * time.Hour) {
		if iterations >= monitorMaintenanceMaxDaysPerRun {
			slog.Info("channel_monitor: maintenance aggregation capped",
				"max_days", monitorMaintenanceMaxDaysPerRun,
				"next_resume", d.Format("2006-01-02"))
			break
		}
		affected, upErr := s.repo.UpsertDailyRollupsFor(ctx, d)
		if upErr != nil {
			return fmt.Errorf("upsert rollups for %s: %w", d.Format("2006-01-02"), upErr)
		}
		if err := s.repo.UpdateAggregationWatermark(ctx, d); err != nil {
			return fmt.Errorf("update watermark to %s: %w", d.Format("2006-01-02"), err)
		}
		slog.Info("channel_monitor: rollups upserted",
			"date", d.Format("2006-01-02"), "affected_rows", affected)
		iterations++
	}
	return nil
}
// resolveAggregationStart picks where this aggregation pass begins:
//   - watermark == nil: today - monitorRollupRetentionDays (first backfill cap)
//   - watermark != nil: the day after the watermark (UTC-truncated)
func (s *ChannelMonitorService) resolveAggregationStart(watermark *time.Time, today time.Time) time.Time {
	if watermark != nil {
		return watermark.UTC().Truncate(24 * time.Hour).Add(24 * time.Hour)
	}
	return today.AddDate(0, 0, -monitorRollupRetentionDays)
}
// cleanupOldRollups soft-deletes daily rollup rows with
// bucket_date < today - monitorRollupRetentionDays.
func (s *ChannelMonitorService) cleanupOldRollups(ctx context.Context, today time.Time) error {
	cutoff := today.AddDate(0, 0, -monitorRollupRetentionDays)
	removed, err := s.repo.DeleteRollupsBefore(ctx, cutoff)
	if err != nil {
		return fmt.Errorf("delete rollups before %s: %w", cutoff.Format("2006-01-02"), err)
	}
	if removed > 0 {
		slog.Info("channel_monitor: rollups cleanup",
			"deleted_rows", removed, "before", cutoff.Format("2006-01-02"))
	}
	return nil
}
// ---------- helpers ----------
// decryptInPlace replaces ChannelMonitor.APIKey ciphertext with plaintext.
// On decrypt failure the field is blanked and APIKeyDecryptFailed is set
// instead of returning an error, so list rendering is never blocked; the
// runner / RunCheck must honor that flag and refuse to run checks.
func (s *ChannelMonitorService) decryptInPlace(m *ChannelMonitor) {
	if m == nil || m.APIKey == "" {
		return
	}
	plain, err := s.encryptor.Decrypt(m.APIKey)
	if err == nil {
		m.APIKey = plain
		return
	}
	slog.Warn("channel_monitor: decrypt api key failed",
		"monitor_id", m.ID, "error", err)
	m.APIKey = ""
	m.APIKeyDecryptFailed = true
}
// applyMonitorUpdate copies every non-nil field from the update params onto
// existing. The APIKey field is handled separately by the caller (it involves
// encryption).
//
// Slightly over 30 lines: this is a flat per-field dispatcher where every if
// is a 1-3 line "overwrite when non-nil"; splitting it would add jump noise
// and hurt readability, so it stays one function.
func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
	if p.Name != nil {
		existing.Name = strings.TrimSpace(*p.Name)
	}
	if p.Provider != nil {
		if err := validateProvider(*p.Provider); err != nil {
			return err
		}
		existing.Provider = *p.Provider
	}
	if p.Endpoint != nil {
		if err := validateEndpoint(*p.Endpoint); err != nil {
			return err
		}
		existing.Endpoint = normalizeEndpoint(*p.Endpoint)
	}
	if p.PrimaryModel != nil {
		existing.PrimaryModel = strings.TrimSpace(*p.PrimaryModel)
	}
	if p.ExtraModels != nil {
		existing.ExtraModels = normalizeModels(*p.ExtraModels)
	}
	if p.GroupName != nil {
		existing.GroupName = strings.TrimSpace(*p.GroupName)
	}
	if p.Enabled != nil {
		existing.Enabled = *p.Enabled
	}
	if p.IntervalSeconds != nil {
		if err := validateInterval(*p.IntervalSeconds); err != nil {
			return err
		}
		existing.IntervalSeconds = *p.IntervalSeconds
	}
	return applyMonitorAdvancedUpdate(existing, p)
}
// applyMonitorAdvancedUpdate handles the custom request-snapshot fields,
// split out of applyMonitorUpdate to keep that function from growing further.
func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
	if p.ClearTemplate {
		existing.TemplateID = nil
	} else if p.TemplateID != nil {
		// Copy the value so existing does not alias the params pointer.
		id := *p.TemplateID
		existing.TemplateID = &id
	}
	if p.ExtraHeaders != nil {
		if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
			return err
		}
		existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
	}
	// BodyOverrideMode / BodyOverride are validated jointly, same as templates.
	newMode := existing.BodyOverrideMode
	newBody := existing.BodyOverride
	if p.BodyOverrideMode != nil {
		newMode = *p.BodyOverrideMode
	}
	if p.BodyOverride != nil {
		newBody = *p.BodyOverride
	}
	if p.BodyOverrideMode != nil || p.BodyOverride != nil {
		if err := validateBodyModeParams(newMode, newBody); err != nil {
			return err
		}
		existing.BodyOverrideMode = defaultBodyMode(newMode)
		existing.BodyOverride = newBody
	}
	return nil
}
package service
import (
"context"
"net"
"strings"
)
// SSRF protection helpers:
//   - validateEndpoint rejects http/loopback/private-network/cloud-metadata
//     URLs at admin submission time
//   - safeDialContext re-validates the real IP at the socket layer to defeat
//     DNS rebinding
//
// Deny list of known cloud-metadata hostnames (compared lowercase).
var monitorBlockedHostnames = map[string]struct{}{
	"localhost": {},
	"localhost.localdomain": {},
	"metadata": {},
	"metadata.google.internal": {},
	"metadata.goog": {},
	"instance-data": {},
	"instance-data.ec2.internal": {},
}
// CIDR list covering every IPv4/IPv6 range we refuse to reach.
// Parsing can only panic once (verified at startup); the production path
// only calls Contains.
var monitorBlockedCIDRs = mustParseCIDRs([]string{
	"127.0.0.0/8", // IPv4 loopback
	"10.0.0.0/8", // RFC1918
	"172.16.0.0/12", // RFC1918
	"192.168.0.0/16", // RFC1918
	"169.254.0.0/16", // link-local (includes cloud metadata 169.254.169.254)
	"100.64.0.0/10", // CGNAT
	"0.0.0.0/8", // "this network"
	"::1/128", // IPv6 loopback
	"fc00::/7", // IPv6 ULA
	"fe80::/10", // IPv6 link-local
	"::/128", // IPv6 unspecified
})
// monitorDialer is the shared Dialer, aligned with net/http's defaults.
var monitorDialer = &net.Dialer{
	Timeout: monitorDialTimeout,
	KeepAlive: monitorDialKeepAlive,
}
// mustParseCIDRs 在包初始化时解析 CIDR 字符串,失败 panic。
func mustParseCIDRs(cidrs []string) []*net.IPNet {
out := make([]*net.IPNet, 0, len(cidrs))
for _, c := range cidrs {
_, n, err := net.ParseCIDR(c)
if err != nil {
panic("channel_monitor_ssrf: invalid CIDR " + c + ": " + err.Error())
}
out = append(out, n)
}
return out
}
// isBlockedHostname reports whether hostname is on the deny list.
// An empty hostname is always treated as blocked.
func isBlockedHostname(hostname string) bool {
	if hostname == "" {
		return true
	}
	_, hit := monitorBlockedHostnames[strings.ToLower(hostname)]
	return hit
}
// isPrivateIP reports whether ip falls inside a forbidden range
// (loopback/RFC1918/link-local/ULA etc.). A nil IP is treated as private.
func isPrivateIP(ip net.IP) bool {
	switch {
	case ip == nil:
		return true
	case ip.IsUnspecified(), ip.IsLoopback(), ip.IsLinkLocalUnicast(), ip.IsLinkLocalMulticast(), ip.IsInterfaceLocalMulticast():
		return true
	}
	for _, blocked := range monitorBlockedCIDRs {
		if blocked.Contains(ip) {
			return true
		}
	}
	return false
}
// isPrivateOrLoopbackHost resolves every A/AAAA record for hostname and
// reports unsafe if any resolved IP lands in a private/loopback range.
// An empty resolution set is also treated as unsafe. IP-literal hostnames
// take the same path via direct parsing.
func isPrivateOrLoopbackHost(ctx context.Context, hostname string) (bool, error) {
	if isBlockedHostname(hostname) {
		return true, nil
	}
	// Literal IPs are decided without a lookup.
	if literal := net.ParseIP(hostname); literal != nil {
		return isPrivateIP(literal), nil
	}
	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
	if err != nil {
		return false, err
	}
	if len(addrs) == 0 {
		return true, nil
	}
	for _, addr := range addrs {
		if isPrivateIP(addr.IP) {
			return true, nil
		}
	}
	return false, nil
}
// safeDialContext re-validates the target IP immediately before the real dial
// to defeat DNS rebinding: the hostname is resolved again here and each IP is
// tried individually, rejecting any private address even if validateEndpoint
// saw a public IP earlier.
func safeDialContext(ctx context.Context, network, address string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}
	// Fast path for literal IPs.
	if ip := net.ParseIP(host); ip != nil {
		if isPrivateIP(ip) {
			return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address}
		}
		return monitorDialer.DialContext(ctx, network, address)
	}
	if isBlockedHostname(host) {
		return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address}
	}
	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	if len(addrs) == 0 {
		return nil, &net.AddrError{Err: "no addresses for host", Addr: host}
	}
	var lastErr error
	for _, a := range addrs {
		if isPrivateIP(a.IP) {
			lastErr = &net.AddrError{Err: "blocked by SSRF policy", Addr: a.IP.String()}
			continue
		}
		// Dial the vetted IP directly (not the hostname) so the connection
		// cannot be redirected by a second, different DNS resolution.
		conn, err := monitorDialer.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port))
		if err == nil {
			return conn, nil
		}
		lastErr = err
	}
	if lastErr == nil {
		lastErr = &net.AddrError{Err: "no usable addresses", Addr: host}
	}
	return nil, lastErr
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment