Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ac114738
Unverified
Commit
ac114738
authored
Apr 24, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 24, 2026
Browse files
Merge pull request #1850 from touwaeriol/feat/channel-insights
feat(monitor): channel monitor with available channels & feature flags
parents
0a80ec80
09fd83ab
Changes
151
Show whitespace changes
Inline
Side-by-side
backend/internal/handler/setting_handler.go
View file @
ac114738
...
...
@@ -70,5 +70,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AccountQuotaNotifyEnabled
:
settings
.
AccountQuotaNotifyEnabled
,
BalanceLowNotifyThreshold
:
settings
.
BalanceLowNotifyThreshold
,
BalanceLowNotifyRechargeURL
:
settings
.
BalanceLowNotifyRechargeURL
,
ChannelMonitorEnabled
:
settings
.
ChannelMonitorEnabled
,
ChannelMonitorDefaultIntervalSeconds
:
settings
.
ChannelMonitorDefaultIntervalSeconds
,
AvailableChannelsEnabled
:
settings
.
AvailableChannelsEnabled
,
})
}
backend/internal/handler/wire.go
View file @
ac114738
...
...
@@ -34,6 +34,8 @@ func ProvideAdminHandlers(
apiKeyHandler
*
admin
.
AdminAPIKeyHandler
,
scheduledTestHandler
*
admin
.
ScheduledTestHandler
,
channelHandler
*
admin
.
ChannelHandler
,
channelMonitorHandler
*
admin
.
ChannelMonitorHandler
,
channelMonitorTemplateHandler
*
admin
.
ChannelMonitorRequestTemplateHandler
,
paymentHandler
*
admin
.
PaymentHandler
,
)
*
AdminHandlers
{
return
&
AdminHandlers
{
...
...
@@ -62,6 +64,8 @@ func ProvideAdminHandlers(
APIKey
:
apiKeyHandler
,
ScheduledTest
:
scheduledTestHandler
,
Channel
:
channelHandler
,
ChannelMonitor
:
channelMonitorHandler
,
ChannelMonitorTemplate
:
channelMonitorTemplateHandler
,
Payment
:
paymentHandler
,
}
}
...
...
@@ -85,6 +89,7 @@ func ProvideHandlers(
redeemHandler
*
RedeemHandler
,
subscriptionHandler
*
SubscriptionHandler
,
announcementHandler
*
AnnouncementHandler
,
channelMonitorUserHandler
*
ChannelMonitorUserHandler
,
adminHandlers
*
AdminHandlers
,
gatewayHandler
*
GatewayHandler
,
openaiGatewayHandler
*
OpenAIGatewayHandler
,
...
...
@@ -92,6 +97,7 @@ func ProvideHandlers(
totpHandler
*
TotpHandler
,
paymentHandler
*
PaymentHandler
,
paymentWebhookHandler
*
PaymentWebhookHandler
,
availableChannelHandler
*
AvailableChannelHandler
,
_
*
service
.
IdempotencyCoordinator
,
_
*
service
.
IdempotencyCleanupService
,
)
*
Handlers
{
...
...
@@ -103,6 +109,7 @@ func ProvideHandlers(
Redeem
:
redeemHandler
,
Subscription
:
subscriptionHandler
,
Announcement
:
announcementHandler
,
ChannelMonitor
:
channelMonitorUserHandler
,
Admin
:
adminHandlers
,
Gateway
:
gatewayHandler
,
OpenAIGateway
:
openaiGatewayHandler
,
...
...
@@ -110,6 +117,7 @@ func ProvideHandlers(
Totp
:
totpHandler
,
Payment
:
paymentHandler
,
PaymentWebhook
:
paymentWebhookHandler
,
AvailableChannel
:
availableChannelHandler
,
}
}
...
...
@@ -123,12 +131,14 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler
,
NewSubscriptionHandler
,
NewAnnouncementHandler
,
NewChannelMonitorUserHandler
,
NewGatewayHandler
,
NewOpenAIGatewayHandler
,
NewTotpHandler
,
ProvideSettingHandler
,
NewPaymentHandler
,
NewPaymentWebhookHandler
,
NewAvailableChannelHandler
,
// Admin handlers
admin
.
NewDashboardHandler
,
...
...
@@ -156,6 +166,8 @@ var ProviderSet = wire.NewSet(
admin
.
NewAdminAPIKeyHandler
,
admin
.
NewScheduledTestHandler
,
admin
.
NewChannelHandler
,
admin
.
NewChannelMonitorHandler
,
admin
.
NewChannelMonitorRequestTemplateHandler
,
admin
.
NewPaymentHandler
,
// AdminHandlers and Handlers constructors
...
...
backend/internal/repository/channel_monitor_repo.go
0 → 100644
View file @
ac114738
package
repository
import
(
"context"
"database/sql"
"fmt"
"strings"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// channelMonitorRepository implements service.ChannelMonitorRepository.
//
// Design notes:
//   - CRUD goes through ent so the project's transaction-context support is reused.
//   - Aggregate queries (latest per model / availability) use raw SQL to avoid
//     ent's GROUP BY boilerplate and to keep the queries index-friendly.
type channelMonitorRepository struct {
	client *dbent.Client // ent client used by the CRUD-style methods
	db     *sql.DB       // raw handle used by the hand-written SQL queries
}
// NewChannelMonitorRepository 创建仓储实例。
func
NewChannelMonitorRepository
(
client
*
dbent
.
Client
,
db
*
sql
.
DB
)
service
.
ChannelMonitorRepository
{
return
&
channelMonitorRepository
{
client
:
client
,
db
:
db
}
}
// ---------- CRUD ----------
func
(
r
*
channelMonitorRepository
)
Create
(
ctx
context
.
Context
,
m
*
service
.
ChannelMonitor
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
ChannelMonitor
.
Create
()
.
SetName
(
m
.
Name
)
.
SetProvider
(
channelmonitor
.
Provider
(
m
.
Provider
))
.
SetEndpoint
(
m
.
Endpoint
)
.
SetAPIKeyEncrypted
(
m
.
APIKey
)
.
// 调用方传入的已是密文
SetPrimaryModel
(
m
.
PrimaryModel
)
.
SetExtraModels
(
emptySliceIfNil
(
m
.
ExtraModels
))
.
SetGroupName
(
m
.
GroupName
)
.
SetEnabled
(
m
.
Enabled
)
.
SetIntervalSeconds
(
m
.
IntervalSeconds
)
.
SetCreatedBy
(
m
.
CreatedBy
)
.
SetExtraHeaders
(
emptyHeadersIfNilRepo
(
m
.
ExtraHeaders
))
.
SetBodyOverrideMode
(
defaultBodyModeRepo
(
m
.
BodyOverrideMode
))
if
m
.
TemplateID
!=
nil
{
builder
=
builder
.
SetTemplateID
(
*
m
.
TemplateID
)
}
if
m
.
BodyOverride
!=
nil
{
builder
=
builder
.
SetBodyOverride
(
m
.
BodyOverride
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorNotFound
,
nil
)
}
m
.
ID
=
created
.
ID
m
.
CreatedAt
=
created
.
CreatedAt
m
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
channelMonitorRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ChannelMonitor
,
error
)
{
row
,
err
:=
r
.
client
.
ChannelMonitor
.
Query
()
.
Where
(
channelmonitor
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorNotFound
,
nil
)
}
return
entToServiceMonitor
(
row
),
nil
}
func
(
r
*
channelMonitorRepository
)
Update
(
ctx
context
.
Context
,
m
*
service
.
ChannelMonitor
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
updater
:=
client
.
ChannelMonitor
.
UpdateOneID
(
m
.
ID
)
.
SetName
(
m
.
Name
)
.
SetProvider
(
channelmonitor
.
Provider
(
m
.
Provider
))
.
SetEndpoint
(
m
.
Endpoint
)
.
SetAPIKeyEncrypted
(
m
.
APIKey
)
.
SetPrimaryModel
(
m
.
PrimaryModel
)
.
SetExtraModels
(
emptySliceIfNil
(
m
.
ExtraModels
))
.
SetGroupName
(
m
.
GroupName
)
.
SetEnabled
(
m
.
Enabled
)
.
SetIntervalSeconds
(
m
.
IntervalSeconds
)
.
SetExtraHeaders
(
emptyHeadersIfNilRepo
(
m
.
ExtraHeaders
))
.
SetBodyOverrideMode
(
defaultBodyModeRepo
(
m
.
BodyOverrideMode
))
if
m
.
TemplateID
!=
nil
{
updater
=
updater
.
SetTemplateID
(
*
m
.
TemplateID
)
}
else
{
updater
=
updater
.
ClearTemplateID
()
}
if
m
.
BodyOverride
!=
nil
{
updater
=
updater
.
SetBodyOverride
(
m
.
BodyOverride
)
}
else
{
updater
=
updater
.
ClearBodyOverride
()
}
updated
,
err
:=
updater
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorNotFound
,
nil
)
}
m
.
UpdatedAt
=
updated
.
UpdatedAt
return
nil
}
func
(
r
*
channelMonitorRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
if
err
:=
client
.
ChannelMonitor
.
DeleteOneID
(
id
)
.
Exec
(
ctx
);
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorNotFound
,
nil
)
}
return
nil
}
func
(
r
*
channelMonitorRepository
)
List
(
ctx
context
.
Context
,
params
service
.
ChannelMonitorListParams
)
([]
*
service
.
ChannelMonitor
,
int64
,
error
)
{
q
:=
r
.
client
.
ChannelMonitor
.
Query
()
if
params
.
Provider
!=
""
{
q
=
q
.
Where
(
channelmonitor
.
ProviderEQ
(
channelmonitor
.
Provider
(
params
.
Provider
)))
}
if
params
.
Enabled
!=
nil
{
q
=
q
.
Where
(
channelmonitor
.
EnabledEQ
(
*
params
.
Enabled
))
}
if
s
:=
strings
.
TrimSpace
(
params
.
Search
);
s
!=
""
{
q
=
q
.
Where
(
channelmonitor
.
Or
(
channelmonitor
.
NameContainsFold
(
s
),
channelmonitor
.
GroupNameContainsFold
(
s
),
channelmonitor
.
PrimaryModelContainsFold
(
s
),
))
}
total
,
err
:=
q
.
Count
(
ctx
)
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"count monitors: %w"
,
err
)
}
pageSize
:=
params
.
PageSize
if
pageSize
<=
0
{
pageSize
=
20
}
page
:=
params
.
Page
if
page
<=
0
{
page
=
1
}
rows
,
err
:=
q
.
Order
(
dbent
.
Desc
(
channelmonitor
.
FieldID
))
.
Offset
((
page
-
1
)
*
pageSize
)
.
Limit
(
pageSize
)
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"list monitors: %w"
,
err
)
}
out
:=
make
([]
*
service
.
ChannelMonitor
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
out
=
append
(
out
,
entToServiceMonitor
(
row
))
}
return
out
,
int64
(
total
),
nil
}
// ---------- 调度器辅助 ----------
func
(
r
*
channelMonitorRepository
)
ListEnabled
(
ctx
context
.
Context
)
([]
*
service
.
ChannelMonitor
,
error
)
{
rows
,
err
:=
r
.
client
.
ChannelMonitor
.
Query
()
.
Where
(
channelmonitor
.
EnabledEQ
(
true
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list enabled monitors: %w"
,
err
)
}
out
:=
make
([]
*
service
.
ChannelMonitor
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
out
=
append
(
out
,
entToServiceMonitor
(
row
))
}
return
out
,
nil
}
func
(
r
*
channelMonitorRepository
)
MarkChecked
(
ctx
context
.
Context
,
id
int64
,
checkedAt
time
.
Time
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
if
err
:=
client
.
ChannelMonitor
.
UpdateOneID
(
id
)
.
SetLastCheckedAt
(
checkedAt
)
.
Exec
(
ctx
);
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorNotFound
,
nil
)
}
return
nil
}
func
(
r
*
channelMonitorRepository
)
InsertHistoryBatch
(
ctx
context
.
Context
,
rows
[]
*
service
.
ChannelMonitorHistoryRow
)
error
{
if
len
(
rows
)
==
0
{
return
nil
}
client
:=
clientFromContext
(
ctx
,
r
.
client
)
bulk
:=
make
([]
*
dbent
.
ChannelMonitorHistoryCreate
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
c
:=
client
.
ChannelMonitorHistory
.
Create
()
.
SetMonitorID
(
row
.
MonitorID
)
.
SetModel
(
row
.
Model
)
.
SetStatus
(
channelmonitorhistory
.
Status
(
row
.
Status
))
.
SetMessage
(
row
.
Message
)
.
SetCheckedAt
(
row
.
CheckedAt
)
if
row
.
LatencyMs
!=
nil
{
c
=
c
.
SetLatencyMs
(
*
row
.
LatencyMs
)
}
if
row
.
PingLatencyMs
!=
nil
{
c
=
c
.
SetPingLatencyMs
(
*
row
.
PingLatencyMs
)
}
bulk
=
append
(
bulk
,
c
)
}
if
_
,
err
:=
client
.
ChannelMonitorHistory
.
CreateBulk
(
bulk
...
)
.
Save
(
ctx
);
err
!=
nil
{
return
fmt
.
Errorf
(
"insert history bulk: %w"
,
err
)
}
return
nil
}
// DeleteHistoryBefore 物理删 checked_at < before 的明细,分批 channelMonitorPruneBatchSize 行一批,
// 避免单事务删除过多引起锁/WAL 压力。借助 (checked_at) 索引定位小批 id,再按 id 删。
func
(
r
*
channelMonitorRepository
)
DeleteHistoryBefore
(
ctx
context
.
Context
,
before
time
.
Time
)
(
int64
,
error
)
{
return
deleteChannelMonitorBatched
(
ctx
,
r
.
db
,
channelMonitorPruneHistorySQL
,
before
)
}
// ListHistory 按 checked_at 倒序返回某个监控的最近 N 条历史记录。
// model 为空时不过滤;非空时只返回该模型的记录。
func
(
r
*
channelMonitorRepository
)
ListHistory
(
ctx
context
.
Context
,
monitorID
int64
,
model
string
,
limit
int
)
([]
*
service
.
ChannelMonitorHistoryEntry
,
error
)
{
q
:=
r
.
client
.
ChannelMonitorHistory
.
Query
()
.
Where
(
channelmonitorhistory
.
MonitorIDEQ
(
monitorID
))
if
strings
.
TrimSpace
(
model
)
!=
""
{
q
=
q
.
Where
(
channelmonitorhistory
.
ModelEQ
(
model
))
}
rows
,
err
:=
q
.
Order
(
dbent
.
Desc
(
channelmonitorhistory
.
FieldCheckedAt
))
.
Limit
(
limit
)
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list history: %w"
,
err
)
}
out
:=
make
([]
*
service
.
ChannelMonitorHistoryEntry
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
entry
:=
&
service
.
ChannelMonitorHistoryEntry
{
ID
:
row
.
ID
,
Model
:
row
.
Model
,
Status
:
string
(
row
.
Status
),
LatencyMs
:
row
.
LatencyMs
,
PingLatencyMs
:
row
.
PingLatencyMs
,
Message
:
row
.
Message
,
CheckedAt
:
row
.
CheckedAt
,
}
out
=
append
(
out
,
entry
)
}
return
out
,
nil
}
// ---------- 用户视图聚合(原生 SQL) ----------
// ListLatestPerModel 用 DISTINCT ON 取每个 (monitor_id, model) 的最近一条记录。
// 借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。
func
(
r
*
channelMonitorRepository
)
ListLatestPerModel
(
ctx
context
.
Context
,
monitorID
int64
)
([]
*
service
.
ChannelMonitorLatest
,
error
)
{
const
q
=
`
SELECT DISTINCT ON (model)
model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = $1
ORDER BY model, checked_at DESC
`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
monitorID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query latest per model: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
out
:=
make
([]
*
service
.
ChannelMonitorLatest
,
0
)
for
rows
.
Next
()
{
l
:=
&
service
.
ChannelMonitorLatest
{}
var
latency
,
ping
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
l
.
Model
,
&
l
.
Status
,
&
latency
,
&
ping
,
&
l
.
CheckedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan latest row: %w"
,
err
)
}
assignNullInt
(
&
l
.
LatencyMs
,
latency
)
assignNullInt
(
&
l
.
PingLatencyMs
,
ping
)
out
=
append
(
out
,
l
)
}
return
out
,
rows
.
Err
()
}
// assignNullInt unpacks a sql.NullInt64 into a *int destination.
//
// When n is valid, *dst is pointed at a freshly allocated int holding n.Int64
// converted to int. When n is not valid, *dst is left untouched. Centralising this
// avoids repeating the same "if Valid { v := int(...) }" snippet for every
// nullable latency column.
func assignNullInt(dst **int, n sql.NullInt64) {
	if n.Valid {
		converted := int(n.Int64)
		*dst = &converted
	}
}
// ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。
// "可用" = status IN (operational, degraded)。
//
// 数据来源:明细表只保留 1 天;窗口前其余天数走聚合表。
// 明细保留 30 天(monitorHistoryRetentionDays),窗口 <= 30 天时直接扫 histories,
// 精度到秒,避免与聚合表 UNION 带来的 UTC 日切精度损失。
func
(
r
*
channelMonitorRepository
)
ComputeAvailability
(
ctx
context
.
Context
,
monitorID
int64
,
windowDays
int
)
([]
*
service
.
ChannelMonitorAvailability
,
error
)
{
if
windowDays
<=
0
{
windowDays
=
7
}
const
q
=
`
SELECT model,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
CASE WHEN COUNT(latency_ms) > 0
THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
ELSE NULL END AS avg_latency_ms
FROM channel_monitor_histories
WHERE monitor_id = $1
AND checked_at >= NOW() - ($2::int || ' days')::interval
GROUP BY model
`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
monitorID
,
windowDays
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query availability: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
out
:=
make
([]
*
service
.
ChannelMonitorAvailability
,
0
)
for
rows
.
Next
()
{
row
,
err
:=
scanAvailabilityRow
(
rows
,
windowDays
)
if
err
!=
nil
{
return
nil
,
err
}
out
=
append
(
out
,
row
)
}
return
out
,
rows
.
Err
()
}
// scanAvailabilityRow 把单行 (model, total, ok, avg_latency) 扫描为 ChannelMonitorAvailability。
// 仅服务于 ComputeAvailability(4 列);批量版本因为多一列 monitor_id 直接 inline 调 finalizeAvailabilityRow。
func
scanAvailabilityRow
(
rows
interface
{
Scan
(
...
any
)
error
},
windowDays
int
)
(
*
service
.
ChannelMonitorAvailability
,
error
)
{
row
:=
&
service
.
ChannelMonitorAvailability
{
WindowDays
:
windowDays
}
var
avgLatency
sql
.
NullFloat64
if
err
:=
rows
.
Scan
(
&
row
.
Model
,
&
row
.
TotalChecks
,
&
row
.
OperationalChecks
,
&
avgLatency
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan availability row: %w"
,
err
)
}
finalizeAvailabilityRow
(
row
,
avgLatency
)
return
row
,
nil
}
// finalizeAvailabilityRow 根据 OperationalChecks/TotalChecks 算出可用率,
// 并把 sql.NullFloat64 的平均延迟解包为 *int。两处复用避免维护漂移。
func
finalizeAvailabilityRow
(
row
*
service
.
ChannelMonitorAvailability
,
avgLatency
sql
.
NullFloat64
)
{
if
row
.
TotalChecks
>
0
{
row
.
AvailabilityPct
=
float64
(
row
.
OperationalChecks
)
*
100.0
/
float64
(
row
.
TotalChecks
)
}
if
avgLatency
.
Valid
{
v
:=
int
(
avgLatency
.
Float64
)
row
.
AvgLatencyMs
=
&
v
}
}
// ListLatestForMonitorIDs 一次性查询多个监控的"每个 (monitor_id, model) 最近一条"记录。
// 利用 PG 的 DISTINCT ON 特性,借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。
func
(
r
*
channelMonitorRepository
)
ListLatestForMonitorIDs
(
ctx
context
.
Context
,
ids
[]
int64
)
(
map
[
int64
][]
*
service
.
ChannelMonitorLatest
,
error
)
{
out
:=
make
(
map
[
int64
][]
*
service
.
ChannelMonitorLatest
,
len
(
ids
))
if
len
(
ids
)
==
0
{
return
out
,
nil
}
const
q
=
`
SELECT DISTINCT ON (monitor_id, model)
monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
ORDER BY monitor_id, model, checked_at DESC
`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
pq
.
Array
(
ids
))
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query latest batch: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
monitorID
int64
l
:=
&
service
.
ChannelMonitorLatest
{}
var
latency
,
ping
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
monitorID
,
&
l
.
Model
,
&
l
.
Status
,
&
latency
,
&
ping
,
&
l
.
CheckedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan latest batch row: %w"
,
err
)
}
assignNullInt
(
&
l
.
LatencyMs
,
latency
)
assignNullInt
(
&
l
.
PingLatencyMs
,
ping
)
out
[
monitorID
]
=
append
(
out
[
monitorID
],
l
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
// ListRecentHistoryForMonitors 为多个 monitor 批量取各自"指定模型"最近 N 条历史(按 checked_at DESC,最新在前)。
// primaryModels[monitorID] 指定该监控要过滤的模型名;monitor 不在 primaryModels 中的记录不返回。
// 通过 CTE + unnest(两个 int8/text 数组) 构造 (monitor_id, model) 白名单,
// 再用 ROW_NUMBER() OVER (PARTITION BY monitor_id) 取各自前 N 条。
//
// 返回值:map[monitorID] -> []*ChannelMonitorHistoryEntry(不含 message,减少网络开销)。
// 空 ids / 空 primaryModels 返回空 map,不报错。
func
(
r
*
channelMonitorRepository
)
ListRecentHistoryForMonitors
(
ctx
context
.
Context
,
ids
[]
int64
,
primaryModels
map
[
int64
]
string
,
perMonitorLimit
int
,
)
(
map
[
int64
][]
*
service
.
ChannelMonitorHistoryEntry
,
error
)
{
out
:=
make
(
map
[
int64
][]
*
service
.
ChannelMonitorHistoryEntry
,
len
(
ids
))
pairIDs
,
pairModels
:=
buildMonitorModelPairs
(
ids
,
primaryModels
)
if
len
(
pairIDs
)
==
0
{
return
out
,
nil
}
perMonitorLimit
=
clampTimelineLimit
(
perMonitorLimit
)
const
q
=
`
WITH targets AS (
SELECT unnest($1::bigint[]) AS monitor_id,
unnest($2::text[]) AS model
),
ranked AS (
SELECT h.monitor_id,
h.status,
h.latency_ms,
h.ping_latency_ms,
h.checked_at,
ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn
FROM channel_monitor_histories h
JOIN targets t
ON t.monitor_id = h.monitor_id AND t.model = h.model
)
SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
FROM ranked
WHERE rn <= $3
ORDER BY monitor_id, checked_at DESC
`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
pq
.
Array
(
pairIDs
),
pq
.
Array
(
pairModels
),
perMonitorLimit
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query recent history batch: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
monitorID
int64
entry
:=
&
service
.
ChannelMonitorHistoryEntry
{}
var
latency
,
ping
sql
.
NullInt64
if
err
:=
rows
.
Scan
(
&
monitorID
,
&
entry
.
Status
,
&
latency
,
&
ping
,
&
entry
.
CheckedAt
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan recent history row: %w"
,
err
)
}
assignNullInt
(
&
entry
.
LatencyMs
,
latency
)
assignNullInt
(
&
entry
.
PingLatencyMs
,
ping
)
out
[
monitorID
]
=
append
(
out
[
monitorID
],
entry
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
// buildMonitorModelPairs turns ids into two parallel slices describing the valid
// (monitor_id, model) pairs, ready to be expanded with unnest on the SQL side.
//
// An id is skipped when it has no entry in primaryModels or when its model is
// blank after trimming (the model itself is passed through untrimmed). Repeated
// ids are emitted only once, in first-seen order: a duplicated target pair would
// otherwise duplicate every joined history row and consume the per-monitor row
// limit. Both returned slices always have the same length; (nil, nil) is returned
// when either input is empty.
func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) {
	if len(ids) == 0 || len(primaryModels) == 0 {
		return nil, nil
	}
	pairIDs := make([]int64, 0, len(ids))
	pairModels := make([]string, 0, len(ids))
	seen := make(map[int64]struct{}, len(ids))
	for _, id := range ids {
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}

		model, ok := primaryModels[id]
		if !ok || strings.TrimSpace(model) == "" {
			continue
		}
		pairIDs = append(pairIDs, id)
		pairModels = append(pairModels, model)
	}
	return pairIDs, pairModels
}
// timelineLimit* bound the perMonitorLimit accepted by batch timeline queries.
// The lower bound of 1 means at least the latest entry is returned; the upper
// bound of 200 caps the response size and the SQL-side memory use (it is the
// ceiling of the ROW_NUMBER window).
const (
	timelineLimitMin = 1
	timelineLimitMax = 200
)

// clampTimelineLimit forces n into [timelineLimitMin, timelineLimitMax] so that
// invalid or oversized limits never reach the query.
func clampTimelineLimit(n int) int {
	switch {
	case n < timelineLimitMin:
		return timelineLimitMin
	case n > timelineLimitMax:
		return timelineLimitMax
	default:
		return n
	}
}
// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。
// 明细保留 30 天,直接扫 histories(窗口 <= 30 天时无需聚合)。
func
(
r
*
channelMonitorRepository
)
ComputeAvailabilityForMonitors
(
ctx
context
.
Context
,
ids
[]
int64
,
windowDays
int
)
(
map
[
int64
][]
*
service
.
ChannelMonitorAvailability
,
error
)
{
out
:=
make
(
map
[
int64
][]
*
service
.
ChannelMonitorAvailability
,
len
(
ids
))
if
len
(
ids
)
==
0
{
return
out
,
nil
}
if
windowDays
<=
0
{
windowDays
=
7
}
const
q
=
`
SELECT monitor_id,
model,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
CASE WHEN COUNT(latency_ms) > 0
THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
ELSE NULL END AS avg_latency_ms
FROM channel_monitor_histories
WHERE monitor_id = ANY($1)
AND checked_at >= NOW() - ($2::int || ' days')::interval
GROUP BY monitor_id, model
`
rows
,
err
:=
r
.
db
.
QueryContext
(
ctx
,
q
,
pq
.
Array
(
ids
),
windowDays
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query availability batch: %w"
,
err
)
}
defer
func
()
{
_
=
rows
.
Close
()
}()
for
rows
.
Next
()
{
var
monitorID
int64
row
:=
&
service
.
ChannelMonitorAvailability
{
WindowDays
:
windowDays
}
var
avgLatency
sql
.
NullFloat64
if
err
:=
rows
.
Scan
(
&
monitorID
,
&
row
.
Model
,
&
row
.
TotalChecks
,
&
row
.
OperationalChecks
,
&
avgLatency
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"scan availability batch row: %w"
,
err
)
}
// 批量查询多了首列 monitor_id;其余字段的可用率/平均延迟换算与单 monitor 版本一致,
// 抽出 finalizeAvailabilityRow 复用,避免两处分别维护除法与 NullFloat 解包。
finalizeAvailabilityRow
(
row
,
avgLatency
)
out
[
monitorID
]
=
append
(
out
[
monitorID
],
row
)
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
// ---------- 聚合维护 ----------
// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))的明细
// 按 (monitor_id, model, bucket_date) 聚合写入 channel_monitor_daily_rollups。
// - 用 ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE 实现幂等回填,
// 重复执行只会用最新统计覆盖;
// - $1::date 让 PG 自动把入参 truncate 到 UTC 日期,调用方不需要预处理 targetDate。
func
(
r
*
channelMonitorRepository
)
UpsertDailyRollupsFor
(
ctx
context
.
Context
,
targetDate
time
.
Time
)
(
int64
,
error
)
{
const
q
=
`
INSERT INTO channel_monitor_daily_rollups (
monitor_id, model, bucket_date,
total_checks, ok_count,
operational_count, degraded_count, failed_count, error_count,
sum_latency_ms, count_latency,
sum_ping_latency_ms, count_ping_latency,
computed_at
)
SELECT
monitor_id,
model,
$1::date AS bucket_date,
COUNT(*) AS total_checks,
COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
COUNT(*) FILTER (WHERE status = 'operational') AS operational_count,
COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count,
COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
COUNT(*) FILTER (WHERE status = 'error') AS error_count,
COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
COUNT(latency_ms) AS count_latency,
COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms,
COUNT(ping_latency_ms) AS count_ping_latency,
NOW()
FROM channel_monitor_histories
WHERE checked_at >= $1::date
AND checked_at < ($1::date + INTERVAL '1 day')
GROUP BY monitor_id, model
ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
total_checks = EXCLUDED.total_checks,
ok_count = EXCLUDED.ok_count,
operational_count = EXCLUDED.operational_count,
degraded_count = EXCLUDED.degraded_count,
failed_count = EXCLUDED.failed_count,
error_count = EXCLUDED.error_count,
sum_latency_ms = EXCLUDED.sum_latency_ms,
count_latency = EXCLUDED.count_latency,
sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
count_ping_latency = EXCLUDED.count_ping_latency,
computed_at = NOW()
`
res
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
q
,
targetDate
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"upsert daily rollups for %s: %w"
,
targetDate
.
Format
(
"2006-01-02"
),
err
)
}
n
,
err
:=
res
.
RowsAffected
()
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"rows affected (upsert rollups): %w"
,
err
)
}
return
n
,
nil
}
// DeleteRollupsBefore 物理删 bucket_date < beforeDate 的聚合行,同样分批。
func
(
r
*
channelMonitorRepository
)
DeleteRollupsBefore
(
ctx
context
.
Context
,
beforeDate
time
.
Time
)
(
int64
,
error
)
{
return
deleteChannelMonitorBatched
(
ctx
,
r
.
db
,
channelMonitorPruneRollupSQL
,
beforeDate
)
}
// Pruning configuration shared by the batched delete helpers.
const (
	// channelMonitorPruneBatchSize is the maximum number of rows removed per
	// DELETE. The original author keeps it at 5000 to match ops_cleanup_service;
	// deleting by id in small batches avoids long transactions and WAL build-up on
	// large tables.
	channelMonitorPruneBatchSize = 5000

	// channelMonitorPruneHistorySQL deletes one batch of expired rows from the
	// detail table: $1 is the checked_at cutoff, $2 the batch size.
	channelMonitorPruneHistorySQL = `
WITH batch AS (
SELECT id FROM channel_monitor_histories
WHERE checked_at < $1
ORDER BY id
LIMIT $2
)
DELETE FROM channel_monitor_histories
WHERE id IN (SELECT id FROM batch)
`

	// channelMonitorPruneRollupSQL deletes one batch of expired rows from the
	// rollup table: $1 is the cutoff (cast with ::date so it compares consistently
	// with the DATE column bucket_date), $2 the batch size.
	channelMonitorPruneRollupSQL = `
WITH batch AS (
SELECT id FROM channel_monitor_daily_rollups
WHERE bucket_date < $1::date
ORDER BY id
LIMIT $2
)
DELETE FROM channel_monitor_daily_rollups
WHERE id IN (SELECT id FROM batch)
`
)
// deleteChannelMonitorBatched 循环执行分批 DELETE,直到影响行为 0。返回累计删除行数。
// cutoff 由调用方按列类型传入(明细用 time.Time 对 TIMESTAMPTZ,rollup 用 time.Time SQL 侧 ::date 转型)。
func
deleteChannelMonitorBatched
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
query
string
,
cutoff
time
.
Time
)
(
int64
,
error
)
{
var
total
int64
for
{
res
,
err
:=
db
.
ExecContext
(
ctx
,
query
,
cutoff
,
channelMonitorPruneBatchSize
)
if
err
!=
nil
{
return
total
,
fmt
.
Errorf
(
"channel_monitor prune batch: %w"
,
err
)
}
affected
,
err
:=
res
.
RowsAffected
()
if
err
!=
nil
{
return
total
,
fmt
.
Errorf
(
"channel_monitor prune rows affected: %w"
,
err
)
}
total
+=
affected
if
affected
==
0
{
break
}
}
return
total
,
nil
}
// LoadAggregationWatermark 读 watermark 表(id=1)。
// watermark 表不是 ent schema(只有一行),直接走原生 SQL。
// - 行不存在或 last_aggregated_date IS NULL:返回 (nil, nil),由调用方决定首次回填策略
func
(
r
*
channelMonitorRepository
)
LoadAggregationWatermark
(
ctx
context
.
Context
)
(
*
time
.
Time
,
error
)
{
const
q
=
`SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1`
var
t
sql
.
NullTime
if
err
:=
r
.
db
.
QueryRowContext
(
ctx
,
q
)
.
Scan
(
&
t
);
err
!=
nil
{
if
err
==
sql
.
ErrNoRows
{
return
nil
,
nil
}
return
nil
,
fmt
.
Errorf
(
"load aggregation watermark: %w"
,
err
)
}
if
!
t
.
Valid
{
return
nil
,
nil
}
return
&
t
.
Time
,
nil
}
// UpdateAggregationWatermark 更新 watermark(UPSERT 到 id=1)。
// $1::date 让 PG 把入参 truncate 到 UTC 日期,与 last_aggregated_date 列的 DATE 类型一致。
func
(
r
*
channelMonitorRepository
)
UpdateAggregationWatermark
(
ctx
context
.
Context
,
date
time
.
Time
)
error
{
const
q
=
`
INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
VALUES (1, $1::date, NOW())
ON CONFLICT (id) DO UPDATE SET
last_aggregated_date = EXCLUDED.last_aggregated_date,
updated_at = NOW()
`
if
_
,
err
:=
r
.
db
.
ExecContext
(
ctx
,
q
,
date
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update aggregation watermark: %w"
,
err
)
}
return
nil
}
// ---------- helpers ----------
func
entToServiceMonitor
(
row
*
dbent
.
ChannelMonitor
)
*
service
.
ChannelMonitor
{
if
row
==
nil
{
return
nil
}
extras
:=
row
.
ExtraModels
if
extras
==
nil
{
extras
=
[]
string
{}
}
headers
:=
row
.
ExtraHeaders
if
headers
==
nil
{
headers
=
map
[
string
]
string
{}
}
out
:=
&
service
.
ChannelMonitor
{
ID
:
row
.
ID
,
Name
:
row
.
Name
,
Provider
:
string
(
row
.
Provider
),
Endpoint
:
row
.
Endpoint
,
APIKey
:
row
.
APIKeyEncrypted
,
// 仍为密文,service 层负责解密
PrimaryModel
:
row
.
PrimaryModel
,
ExtraModels
:
extras
,
GroupName
:
row
.
GroupName
,
Enabled
:
row
.
Enabled
,
IntervalSeconds
:
row
.
IntervalSeconds
,
LastCheckedAt
:
row
.
LastCheckedAt
,
CreatedBy
:
row
.
CreatedBy
,
CreatedAt
:
row
.
CreatedAt
,
UpdatedAt
:
row
.
UpdatedAt
,
ExtraHeaders
:
headers
,
BodyOverrideMode
:
row
.
BodyOverrideMode
,
BodyOverride
:
row
.
BodyOverride
,
}
if
row
.
TemplateID
!=
nil
{
id
:=
*
row
.
TemplateID
out
.
TemplateID
=
&
id
}
return
out
}
// emptyHeadersIfNilRepo returns h unchanged when it is non-nil and a new empty map
// otherwise. The original author notes it mirrors service.emptyHeadersIfNil and is
// duplicated in the repository package to avoid an import cycle.
func emptyHeadersIfNilRepo(h map[string]string) map[string]string {
	if h != nil {
		return h
	}
	return map[string]string{}
}
// defaultBodyModeRepo normalises an empty body-override mode to "off" and returns
// any other value unchanged. Like emptyHeadersIfNilRepo, it lives in the
// repository package to avoid an import cycle (per the original author).
func defaultBodyModeRepo(mode string) string {
	switch mode {
	case "":
		return "off"
	default:
		return mode
	}
}
// emptySliceIfNil returns in unchanged when it is non-nil and a new empty slice
// otherwise.
func emptySliceIfNil(in []string) []string {
	if in != nil {
		return in
	}
	return []string{}
}
backend/internal/repository/channel_monitor_template_repo.go
0 → 100644
View file @
ac114738
package
repository
import
(
"context"
"database/sql"
"fmt"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
"github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// channelMonitorRequestTemplateRepository implements
// service.ChannelMonitorRequestTemplateRepository. It is kept in a separate file
// from channelMonitorRepository to keep responsibilities clear.
type channelMonitorRequestTemplateRepository struct {
	client *dbent.Client // ent client used by the template and monitor queries
	db     *sql.DB       // raw handle supplied by the constructor
}
// NewChannelMonitorRequestTemplateRepository 创建模板仓储实例。
func
NewChannelMonitorRequestTemplateRepository
(
client
*
dbent
.
Client
,
db
*
sql
.
DB
)
service
.
ChannelMonitorRequestTemplateRepository
{
return
&
channelMonitorRequestTemplateRepository
{
client
:
client
,
db
:
db
}
}
// ---------- CRUD ----------
func
(
r
*
channelMonitorRequestTemplateRepository
)
Create
(
ctx
context
.
Context
,
t
*
service
.
ChannelMonitorRequestTemplate
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
builder
:=
client
.
ChannelMonitorRequestTemplate
.
Create
()
.
SetName
(
t
.
Name
)
.
SetProvider
(
channelmonitorrequesttemplate
.
Provider
(
t
.
Provider
))
.
SetDescription
(
t
.
Description
)
.
SetExtraHeaders
(
emptyHeadersIfNilRepo
(
t
.
ExtraHeaders
))
.
SetBodyOverrideMode
(
defaultBodyModeRepo
(
t
.
BodyOverrideMode
))
if
t
.
BodyOverride
!=
nil
{
builder
=
builder
.
SetBodyOverride
(
t
.
BodyOverride
)
}
created
,
err
:=
builder
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorTemplateNotFound
,
nil
)
}
t
.
ID
=
created
.
ID
t
.
CreatedAt
=
created
.
CreatedAt
t
.
UpdatedAt
=
created
.
UpdatedAt
return
nil
}
func
(
r
*
channelMonitorRequestTemplateRepository
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ChannelMonitorRequestTemplate
,
error
)
{
row
,
err
:=
r
.
client
.
ChannelMonitorRequestTemplate
.
Query
()
.
Where
(
channelmonitorrequesttemplate
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
return
nil
,
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorTemplateNotFound
,
nil
)
}
return
entToServiceTemplate
(
row
),
nil
}
func
(
r
*
channelMonitorRequestTemplateRepository
)
Update
(
ctx
context
.
Context
,
t
*
service
.
ChannelMonitorRequestTemplate
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
updater
:=
client
.
ChannelMonitorRequestTemplate
.
UpdateOneID
(
t
.
ID
)
.
SetName
(
t
.
Name
)
.
SetDescription
(
t
.
Description
)
.
SetExtraHeaders
(
emptyHeadersIfNilRepo
(
t
.
ExtraHeaders
))
.
SetBodyOverrideMode
(
defaultBodyModeRepo
(
t
.
BodyOverrideMode
))
if
t
.
BodyOverride
!=
nil
{
updater
=
updater
.
SetBodyOverride
(
t
.
BodyOverride
)
}
else
{
updater
=
updater
.
ClearBodyOverride
()
}
updated
,
err
:=
updater
.
Save
(
ctx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorTemplateNotFound
,
nil
)
}
t
.
UpdatedAt
=
updated
.
UpdatedAt
return
nil
}
func
(
r
*
channelMonitorRequestTemplateRepository
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
if
err
:=
client
.
ChannelMonitorRequestTemplate
.
DeleteOneID
(
id
)
.
Exec
(
ctx
);
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorTemplateNotFound
,
nil
)
}
return
nil
}
func
(
r
*
channelMonitorRequestTemplateRepository
)
List
(
ctx
context
.
Context
,
params
service
.
ChannelMonitorRequestTemplateListParams
)
([]
*
service
.
ChannelMonitorRequestTemplate
,
error
)
{
q
:=
r
.
client
.
ChannelMonitorRequestTemplate
.
Query
()
if
params
.
Provider
!=
""
{
q
=
q
.
Where
(
channelmonitorrequesttemplate
.
ProviderEQ
(
channelmonitorrequesttemplate
.
Provider
(
params
.
Provider
)))
}
rows
,
err
:=
q
.
Order
(
dbent
.
Asc
(
channelmonitorrequesttemplate
.
FieldProvider
),
dbent
.
Asc
(
channelmonitorrequesttemplate
.
FieldName
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list monitor templates: %w"
,
err
)
}
out
:=
make
([]
*
service
.
ChannelMonitorRequestTemplate
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
out
=
append
(
out
,
entToServiceTemplate
(
row
))
}
return
out
,
nil
}
// ApplyToMonitors 把模板当前配置覆盖到 monitorIDs 列表里的关联监控。
// WHERE 双重过滤:template_id = id AND id IN (monitorIDs),防止用户传了未关联本模板的 id
// 就被覆盖。走 ent UpdateMany 保留 hooks。
func
(
r
*
channelMonitorRequestTemplateRepository
)
ApplyToMonitors
(
ctx
context
.
Context
,
id
int64
,
monitorIDs
[]
int64
)
(
int64
,
error
)
{
if
len
(
monitorIDs
)
==
0
{
return
0
,
nil
}
client
:=
clientFromContext
(
ctx
,
r
.
client
)
tpl
,
err
:=
client
.
ChannelMonitorRequestTemplate
.
Query
()
.
Where
(
channelmonitorrequesttemplate
.
IDEQ
(
id
))
.
Only
(
ctx
)
if
err
!=
nil
{
return
0
,
translatePersistenceError
(
err
,
service
.
ErrChannelMonitorTemplateNotFound
,
nil
)
}
updater
:=
client
.
ChannelMonitor
.
Update
()
.
Where
(
channelmonitor
.
TemplateIDEQ
(
id
),
channelmonitor
.
IDIn
(
monitorIDs
...
),
)
.
SetExtraHeaders
(
emptyHeadersIfNilRepo
(
tpl
.
ExtraHeaders
))
.
SetBodyOverrideMode
(
defaultBodyModeRepo
(
tpl
.
BodyOverrideMode
))
if
tpl
.
BodyOverride
!=
nil
{
updater
=
updater
.
SetBodyOverride
(
tpl
.
BodyOverride
)
}
else
{
updater
=
updater
.
ClearBodyOverride
()
}
affected
,
err
:=
updater
.
Save
(
ctx
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"apply template to monitors: %w"
,
err
)
}
return
int64
(
affected
),
nil
}
// CountAssociatedMonitors 统计关联监控数(UI 展示「N 个配置」用)。
func
(
r
*
channelMonitorRequestTemplateRepository
)
CountAssociatedMonitors
(
ctx
context
.
Context
,
id
int64
)
(
int64
,
error
)
{
count
,
err
:=
r
.
client
.
ChannelMonitor
.
Query
()
.
Where
(
channelmonitor
.
TemplateIDEQ
(
id
))
.
Count
(
ctx
)
if
err
!=
nil
{
return
0
,
fmt
.
Errorf
(
"count monitors for template %d: %w"
,
id
,
err
)
}
return
int64
(
count
),
nil
}
// ListAssociatedMonitors 列出模板关联的所有监控简略字段。
// ORDER BY name 稳定输出方便前端展示。
func
(
r
*
channelMonitorRequestTemplateRepository
)
ListAssociatedMonitors
(
ctx
context
.
Context
,
id
int64
)
([]
*
service
.
AssociatedMonitorBrief
,
error
)
{
rows
,
err
:=
r
.
client
.
ChannelMonitor
.
Query
()
.
Where
(
channelmonitor
.
TemplateIDEQ
(
id
))
.
Order
(
dbent
.
Asc
(
channelmonitor
.
FieldName
))
.
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list associated monitors for template %d: %w"
,
id
,
err
)
}
out
:=
make
([]
*
service
.
AssociatedMonitorBrief
,
0
,
len
(
rows
))
for
_
,
row
:=
range
rows
{
out
=
append
(
out
,
&
service
.
AssociatedMonitorBrief
{
ID
:
row
.
ID
,
Name
:
row
.
Name
,
Provider
:
string
(
row
.
Provider
),
Enabled
:
row
.
Enabled
,
})
}
return
out
,
nil
}
// ---------- helpers ----------
func
entToServiceTemplate
(
row
*
dbent
.
ChannelMonitorRequestTemplate
)
*
service
.
ChannelMonitorRequestTemplate
{
if
row
==
nil
{
return
nil
}
headers
:=
row
.
ExtraHeaders
if
headers
==
nil
{
headers
=
map
[
string
]
string
{}
}
return
&
service
.
ChannelMonitorRequestTemplate
{
ID
:
row
.
ID
,
Name
:
row
.
Name
,
Provider
:
string
(
row
.
Provider
),
Description
:
row
.
Description
,
ExtraHeaders
:
headers
,
BodyOverrideMode
:
row
.
BodyOverrideMode
,
BodyOverride
:
row
.
BodyOverride
,
CreatedAt
:
row
.
CreatedAt
,
UpdatedAt
:
row
.
UpdatedAt
,
}
}
backend/internal/repository/wire.go
View file @
ac114738
...
...
@@ -89,6 +89,8 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository
,
NewTLSFingerprintProfileRepository
,
NewChannelRepository
,
NewChannelMonitorRepository
,
NewChannelMonitorRequestTemplateRepository
,
// Cache implementations
NewGatewayCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
ac114738
...
...
@@ -771,6 +771,9 @@ func TestAPIContracts(t *testing.T) {
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
...
...
@@ -943,6 +946,9 @@ func TestAPIContracts(t *testing.T) {
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"channel_monitor_enabled": true,
"channel_monitor_default_interval_seconds": 60,
"available_channels_enabled": false,
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
...
...
backend/internal/server/routes/admin.go
View file @
ac114738
...
...
@@ -88,6 +88,9 @@ func RegisterAdminRoutes(
// 渠道管理
registerChannelRoutes
(
admin
,
h
)
// 渠道监控
registerChannelMonitorRoutes
(
admin
,
h
)
}
}
...
...
@@ -567,3 +570,27 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels
.
DELETE
(
"/:id"
,
h
.
Admin
.
Channel
.
Delete
)
}
}
func
registerChannelMonitorRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
monitors
:=
admin
.
Group
(
"/channel-monitors"
)
{
monitors
.
GET
(
""
,
h
.
Admin
.
ChannelMonitor
.
List
)
monitors
.
POST
(
""
,
h
.
Admin
.
ChannelMonitor
.
Create
)
monitors
.
GET
(
"/:id"
,
h
.
Admin
.
ChannelMonitor
.
Get
)
monitors
.
PUT
(
"/:id"
,
h
.
Admin
.
ChannelMonitor
.
Update
)
monitors
.
DELETE
(
"/:id"
,
h
.
Admin
.
ChannelMonitor
.
Delete
)
monitors
.
POST
(
"/:id/run"
,
h
.
Admin
.
ChannelMonitor
.
Run
)
monitors
.
GET
(
"/:id/history"
,
h
.
Admin
.
ChannelMonitor
.
History
)
}
templates
:=
admin
.
Group
(
"/channel-monitor-templates"
)
{
templates
.
GET
(
""
,
h
.
Admin
.
ChannelMonitorTemplate
.
List
)
templates
.
POST
(
""
,
h
.
Admin
.
ChannelMonitorTemplate
.
Create
)
templates
.
GET
(
"/:id"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Get
)
templates
.
PUT
(
"/:id"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Update
)
templates
.
DELETE
(
"/:id"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Delete
)
templates
.
GET
(
"/:id/monitors"
,
h
.
Admin
.
ChannelMonitorTemplate
.
AssociatedMonitors
)
templates
.
POST
(
"/:id/apply"
,
h
.
Admin
.
ChannelMonitorTemplate
.
Apply
)
}
}
backend/internal/server/routes/user.go
View file @
ac114738
...
...
@@ -68,6 +68,12 @@ func RegisterUserRoutes(
groups
.
GET
(
"/rates"
,
h
.
APIKey
.
GetUserGroupRates
)
}
// 用户可用渠道(非管理员接口)
channels
:=
authenticated
.
Group
(
"/channels"
)
{
channels
.
GET
(
"/available"
,
h
.
AvailableChannel
.
List
)
}
// 使用记录
usage
:=
authenticated
.
Group
(
"/usage"
)
{
...
...
@@ -103,5 +109,12 @@ func RegisterUserRoutes(
subscriptions
.
GET
(
"/progress"
,
h
.
Subscription
.
GetProgress
)
subscriptions
.
GET
(
"/summary"
,
h
.
Subscription
.
GetSummary
)
}
// 渠道监控(用户只读)
monitors
:=
authenticated
.
Group
(
"/channel-monitors"
)
{
monitors
.
GET
(
""
,
h
.
ChannelMonitor
.
List
)
monitors
.
GET
(
"/:id/status"
,
h
.
ChannelMonitor
.
GetStatus
)
}
}
}
backend/internal/service/channel.go
View file @
ac114738
...
...
@@ -111,6 +111,18 @@ func (c *Channel) IsActive() bool {
return
c
.
Status
==
StatusActive
}
// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
// 作为 *Channel 的实体方法集中管理默认值,service 层只需在 Channel 进入内存
// (缓存装填、repo 读出)时调用一次,下游读路径就无需重复兜底。
func
(
c
*
Channel
)
normalizeBillingModelSource
()
{
if
c
==
nil
{
return
}
if
c
.
BillingModelSource
==
""
{
c
.
BillingModelSource
=
BillingModelSourceChannelMapped
}
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func
(
c
*
Channel
)
GetModelPricing
(
model
string
)
*
ChannelModelPricing
{
...
...
@@ -345,3 +357,209 @@ type ChannelUsageFields struct {
BillingModelSource
string
// 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain
string
// 映射链描述,如 "a→b→c"
}
// SupportedModel is one supported-model entry of a channel. It never contains
// a wildcard, so it can be shown to users as-is.
type SupportedModel struct {
	Name     string               // user-facing model name
	Platform string               // platform the model belongs to
	Pricing  *ChannelModelPricing // pricing details (nil means no pricing is configured)
}
// wildcardSuffix is the wildcard marker in a model pattern. Only a trailing
// marker is recognized.
const wildcardSuffix = "*"

// splitWildcardSuffix splits a model pattern into (prefix, isWildcard).
//
//	"claude-opus-*" → ("claude-opus-", true)
//	"claude-opus-4" → ("claude-opus-4", false)
//	"*"             → ("", true)
//
// The returned prefix keeps its original case; callers lower-case it if needed.
func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
	if !strings.HasSuffix(pattern, wildcardSuffix) {
		return pattern, false
	}
	return pattern[:len(pattern)-len(wildcardSuffix)], true
}
// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。
// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。
func
(
c
*
Channel
)
GetModelPricingByPlatform
(
platform
,
model
string
)
*
ChannelModelPricing
{
if
c
==
nil
{
return
nil
}
modelLower
:=
strings
.
ToLower
(
model
)
for
i
:=
range
c
.
ModelPricing
{
if
c
.
ModelPricing
[
i
]
.
Platform
!=
platform
{
continue
}
for
_
,
m
:=
range
c
.
ModelPricing
[
i
]
.
Models
{
if
strings
.
ToLower
(
m
)
==
modelLower
{
cp
:=
c
.
ModelPricing
[
i
]
.
Clone
()
return
&
cp
}
}
}
return
nil
}
// platformPricingIndex is the combined pricing index of a single platform.
// One scan of the pricing list supports both exact lookups (the exact branch)
// and ordered iteration (the wildcard branch), so SupportedModels does not
// rescan the pricing list for every platform.
//
// byLower and names/originalCase share one dedup rule: the key is the
// lower-cased model name and the first hit keeps its original case. names
// preserves the order in which the pricing rows were scanned.
type platformPricingIndex struct {
	byLower      map[string]*ChannelModelPricing // lowercased model name → pricing (Clone'd)
	originalCase map[string]string               // lowercased model name → original-case model name
	names        []string                        // priced model names in their ORIGINAL case, insertion-ordered, deduped case-insensitively (first wins)
}
// buildPricingIndex 对渠道的定价列表做一次扫描,按 platform 聚合为查找索引。
// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
// 通配符后缀条目(如 "claude-*")不被索引(它们是模式,不是具体模型名)。
// 同一平台中以大小写不敏感方式去重,先出现者保留原始大小写。
func
buildPricingIndex
(
pricings
[]
ChannelModelPricing
)
map
[
string
]
*
platformPricingIndex
{
idx
:=
make
(
map
[
string
]
*
platformPricingIndex
)
for
i
:=
range
pricings
{
p
:=
pricings
[
i
]
pidx
,
ok
:=
idx
[
p
.
Platform
]
if
!
ok
{
pidx
=
&
platformPricingIndex
{
byLower
:
make
(
map
[
string
]
*
ChannelModelPricing
),
originalCase
:
make
(
map
[
string
]
string
),
names
:
make
([]
string
,
0
),
}
idx
[
p
.
Platform
]
=
pidx
}
for
_
,
m
:=
range
p
.
Models
{
if
_
,
wild
:=
splitWildcardSuffix
(
m
);
wild
{
continue
}
lower
:=
strings
.
ToLower
(
m
)
if
_
,
exists
:=
pidx
.
byLower
[
lower
];
exists
{
continue
// 首个命中胜出(case-insensitive 去重后第一个定价 / 第一个原始大小写)
}
cp
:=
pricings
[
i
]
.
Clone
()
pidx
.
byLower
[
lower
]
=
&
cp
pidx
.
originalCase
[
lower
]
=
m
pidx
.
names
=
append
(
pidx
.
names
,
m
)
}
}
return
idx
}
// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
//
// 算法(mapping ∪ pricing 并联):
//
// - Pass A(mapping):遍历 ModelMapping
// - 精确 src → target:显示名 = src(用户视角),定价用 target 在同 platform 定价里查
// (mapping 改写后实际计费的是 target;这是用户感知的"实际花费")。
// target 为空或为通配符时退化为按 src 自查。
// - 通配符 src(如 "claude-3-*"):用同 platform 定价里前缀匹配的模型作为候选展开,
// 每个候选用自身定价(通配符场景一般是 passthrough,target 通常也是通配符)。
// - "*" 单独 mapping key 走通配符分支(前缀为空 → 全展开)。
// - Pass B(pricing-only):遍历 ModelPricing 中所有非通配符模型,对未在 Pass A 添加过的
// 补齐——显示名 = 定价模型名,定价 = 自身(这是关键修复:定价存在即代表渠道支持该模型,
// 即使没配映射)。
//
// 显示名命中定价时使用**定价的原始大小写**(定价是模型身份的事实来源)。
// 按 (Platform, Name) 稳定排序,按 (Platform, lowercase(Name)) 去重,先到者胜出。
//
// 注意:定价仅在 channel.ModelPricing 内查找——全局 LiteLLM 回落由调用方
// (`ChannelService.ListAvailable`)在合成展示数据时叠加。
func
(
c
*
Channel
)
SupportedModels
()
[]
SupportedModel
{
if
c
==
nil
{
return
nil
}
if
len
(
c
.
ModelMapping
)
==
0
&&
len
(
c
.
ModelPricing
)
==
0
{
return
nil
}
idx
:=
buildPricingIndex
(
c
.
ModelPricing
)
type
dedupKey
struct
{
platform
string
name
string
}
seen
:=
make
(
map
[
dedupKey
]
struct
{})
result
:=
make
([]
SupportedModel
,
0
)
// lookup 在 platform pricing index 中按精确名查定价,命中时返回定价大小写。
lookup
:=
func
(
pidx
*
platformPricingIndex
,
name
string
)
(
display
string
,
pricing
*
ChannelModelPricing
)
{
if
pidx
==
nil
||
name
==
""
{
return
name
,
nil
}
lower
:=
strings
.
ToLower
(
name
)
if
p
,
ok
:=
pidx
.
byLower
[
lower
];
ok
{
return
pidx
.
originalCase
[
lower
],
p
}
return
name
,
nil
}
add
:=
func
(
platform
,
displayName
string
,
pricing
*
ChannelModelPricing
)
{
key
:=
dedupKey
{
platform
:
platform
,
name
:
strings
.
ToLower
(
displayName
)}
if
_
,
ok
:=
seen
[
key
];
ok
{
return
}
seen
[
key
]
=
struct
{}{}
result
=
append
(
result
,
SupportedModel
{
Name
:
displayName
,
Platform
:
platform
,
Pricing
:
pricing
,
})
}
// Pass A:从 mapping 展开
for
platform
,
mapping
:=
range
c
.
ModelMapping
{
if
len
(
mapping
)
==
0
{
continue
}
pidx
:=
idx
[
platform
]
for
src
,
target
:=
range
mapping
{
prefix
,
isWild
:=
splitWildcardSuffix
(
src
)
if
isWild
{
if
pidx
==
nil
{
continue
}
prefixLower
:=
strings
.
ToLower
(
prefix
)
for
_
,
candidate
:=
range
pidx
.
names
{
if
strings
.
HasPrefix
(
strings
.
ToLower
(
candidate
),
prefixLower
)
{
display
,
pricing
:=
lookup
(
pidx
,
candidate
)
add
(
platform
,
display
,
pricing
)
}
}
continue
}
// 精确 mapping:定价按 target 查;target 缺失/通配则退化按 src 查
pricingKey
:=
target
if
pricingKey
==
""
{
pricingKey
=
src
}
if
_
,
targetWild
:=
splitWildcardSuffix
(
pricingKey
);
targetWild
{
pricingKey
=
src
}
_
,
pricing
:=
lookup
(
pidx
,
pricingKey
)
// 显示名优先用 src 在定价里的原始大小写(若 src 本身是个定价模型名)
displayName
,
_
:=
lookup
(
pidx
,
src
)
add
(
platform
,
displayName
,
pricing
)
}
}
// Pass B:从 pricing 补齐 mapping 未覆盖的具体模型(修复"定价存在但没配映射 → 不显示")
for
platform
,
pidx
:=
range
idx
{
for
_
,
name
:=
range
pidx
.
names
{
display
,
pricing
:=
lookup
(
pidx
,
name
)
add
(
platform
,
display
,
pricing
)
}
}
sort
.
SliceStable
(
result
,
func
(
i
,
j
int
)
bool
{
if
result
[
i
]
.
Platform
!=
result
[
j
]
.
Platform
{
return
result
[
i
]
.
Platform
<
result
[
j
]
.
Platform
}
return
result
[
i
]
.
Name
<
result
[
j
]
.
Name
})
return
result
}
backend/internal/service/channel_available.go
0 → 100644
View file @
ac114738
package
service
import
(
"context"
"fmt"
"sort"
"strings"
)
// AvailableGroupRef is the brief information about a group associated with a
// channel in the available-channels view.
//
// The user-facing "available channels" page uses it to show exclusive vs.
// public groups (IsExclusive), subscription vs. standard (SubscriptionType)
// and the default rate multiplier (RateMultiplier). User-specific rates are
// not exposed here; per the original note the frontend fetches them from
// /groups/rates, the same way the API key page does.
type AvailableGroupRef struct {
	ID               int64
	Name             string
	Platform         string
	SubscriptionType string
	RateMultiplier   float64
	IsExclusive      bool
}
// AvailableChannel is the available-channel view used by the "available
// channels" page: the channel's basic information, its associated groups, and
// the derived supported-model list (which contains no wildcards).
type AvailableChannel struct {
	ID                 int64
	Name               string
	Description        string
	Status             string
	BillingModelSource string
	RestrictModels     bool
	Groups             []AvailableGroupRef
	SupportedModels    []SupportedModel
}
// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
//
// 支持模型通过 (*Channel).SupportedModels() 计算(mapping ∪ pricing 并联)。
// 对于渠道未配置定价的模型,进一步用 PricingService 的全局 LiteLLM 数据合成
// 一份展示用定价,让用户看到默认价格而非"未配置"。
//
// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
// 的分组(已停用或删除)会被忽略。
//
// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast,
// 避免静默掩盖注入缺失。
func
(
s
*
ChannelService
)
ListAvailable
(
ctx
context
.
Context
)
([]
AvailableChannel
,
error
)
{
channels
,
err
:=
s
.
repo
.
ListAll
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list channels: %w"
,
err
)
}
groups
,
err
:=
s
.
groupRepo
.
ListActive
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list active groups: %w"
,
err
)
}
groupByID
:=
make
(
map
[
int64
]
AvailableGroupRef
,
len
(
groups
))
for
i
:=
range
groups
{
g
:=
groups
[
i
]
groupByID
[
g
.
ID
]
=
AvailableGroupRef
{
ID
:
g
.
ID
,
Name
:
g
.
Name
,
Platform
:
g
.
Platform
,
SubscriptionType
:
g
.
SubscriptionType
,
RateMultiplier
:
g
.
RateMultiplier
,
IsExclusive
:
g
.
IsExclusive
,
}
}
out
:=
make
([]
AvailableChannel
,
0
,
len
(
channels
))
for
i
:=
range
channels
{
ch
:=
&
channels
[
i
]
groups
:=
make
([]
AvailableGroupRef
,
0
,
len
(
ch
.
GroupIDs
))
for
_
,
gid
:=
range
ch
.
GroupIDs
{
if
ref
,
ok
:=
groupByID
[
gid
];
ok
{
groups
=
append
(
groups
,
ref
)
}
}
sort
.
SliceStable
(
groups
,
func
(
i
,
j
int
)
bool
{
return
groups
[
i
]
.
Name
<
groups
[
j
]
.
Name
})
ch
.
normalizeBillingModelSource
()
supported
:=
ch
.
SupportedModels
()
s
.
fillGlobalPricingFallback
(
supported
)
out
=
append
(
out
,
AvailableChannel
{
ID
:
ch
.
ID
,
Name
:
ch
.
Name
,
Description
:
ch
.
Description
,
Status
:
ch
.
Status
,
BillingModelSource
:
ch
.
BillingModelSource
,
RestrictModels
:
ch
.
RestrictModels
,
Groups
:
groups
,
SupportedModels
:
supported
,
})
}
sort
.
SliceStable
(
out
,
func
(
i
,
j
int
)
bool
{
return
strings
.
ToLower
(
out
[
i
]
.
Name
)
<
strings
.
ToLower
(
out
[
j
]
.
Name
)
})
return
out
,
nil
}
// fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份
// 展示用定价(按 token 计费)。仅用于「可用渠道」展示,不影响真实计费链路。
//
// 当 s.pricingService 为 nil(测试场景),跳过回落。
func
(
s
*
ChannelService
)
fillGlobalPricingFallback
(
models
[]
SupportedModel
)
{
if
s
.
pricingService
==
nil
{
return
}
for
i
:=
range
models
{
if
models
[
i
]
.
Pricing
!=
nil
{
continue
}
lp
:=
s
.
pricingService
.
GetModelPricing
(
models
[
i
]
.
Name
)
if
lp
==
nil
{
continue
}
models
[
i
]
.
Pricing
=
synthesizePricingFromLiteLLM
(
lp
)
}
}
// synthesizePricingFromLiteLLM 把 LiteLLM 的定价数据转成 ChannelModelPricing 形态,
// 仅用于展示。BillingMode 固定为 token;图片场景的 OutputCostPerImageToken 也归到
// ImageOutputPrice 字段(与渠道侧"图片输出按 token 计价"语义一致)。
//
// LiteLLM 中字段 0 视为未配置,不带入展示。
func
synthesizePricingFromLiteLLM
(
lp
*
LiteLLMModelPricing
)
*
ChannelModelPricing
{
if
lp
==
nil
{
return
nil
}
return
&
ChannelModelPricing
{
BillingMode
:
BillingModeToken
,
InputPrice
:
nonZeroPtr
(
lp
.
InputCostPerToken
),
OutputPrice
:
nonZeroPtr
(
lp
.
OutputCostPerToken
),
CacheWritePrice
:
nonZeroPtr
(
lp
.
CacheCreationInputTokenCost
),
CacheReadPrice
:
nonZeroPtr
(
lp
.
CacheReadInputTokenCost
),
ImageOutputPrice
:
nonZeroPtr
(
lp
.
OutputCostPerImageToken
),
}
}
// nonZeroPtr returns a pointer to a copy of v, or nil when v equals 0.
func nonZeroPtr(v float64) *float64 {
	if v != 0 {
		return &v
	}
	return nil
}
backend/internal/service/channel_available_test.go
0 → 100644
View file @
ac114738
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
// listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。
// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。
type
stubGroupRepoForAvailable
struct
{
activeGroups
[]
Group
listActiveErr
error
listActiveCalls
int
}
func
(
s
*
stubGroupRepoForAvailable
)
ListActive
(
ctx
context
.
Context
)
([]
Group
,
error
)
{
s
.
listActiveCalls
++
if
s
.
listActiveErr
!=
nil
{
return
nil
,
s
.
listActiveErr
}
return
s
.
activeGroups
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
Create
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
return
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
DeleteCascade
(
ctx
context
.
Context
,
id
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
List
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
ListWithFilters
(
ctx
context
.
Context
,
params
pagination
.
PaginationParams
,
platform
,
status
,
search
string
,
isExclusive
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
return
nil
,
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
ListActiveByPlatform
(
ctx
context
.
Context
,
platform
string
)
([]
Group
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
ExistsByName
(
ctx
context
.
Context
,
name
string
)
(
bool
,
error
)
{
return
false
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
GetAccountCount
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
int64
,
error
)
{
return
0
,
0
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
DeleteAccountGroupsByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
(
int64
,
error
)
{
return
0
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
GetAccountIDsByGroupIDs
(
ctx
context
.
Context
,
groupIDs
[]
int64
)
([]
int64
,
error
)
{
return
nil
,
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
BindAccountsToGroup
(
ctx
context
.
Context
,
groupID
int64
,
accountIDs
[]
int64
)
error
{
return
nil
}
func
(
s
*
stubGroupRepoForAvailable
)
UpdateSortOrders
(
ctx
context
.
Context
,
updates
[]
GroupSortOrderUpdate
)
error
{
return
nil
}
// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels,
// groupRepo 由参数决定。传入空 stub 表示「活跃分组列表为空」。
func
newAvailableChannelService
(
channels
[]
Channel
,
groupRepo
GroupRepository
)
*
ChannelService
{
repo
:=
&
mockChannelRepository
{
listAllFn
:
func
(
ctx
context
.
Context
)
([]
Channel
,
error
)
{
return
channels
,
nil
},
}
return
NewChannelService
(
repo
,
groupRepo
,
nil
,
nil
)
}
func
TestListAvailable_EmptyActiveGroups_NoGroupsAttached
(
t
*
testing
.
T
)
{
// 活跃分组列表为空时,渠道的 Groups 应为空切片,不报错。
channels
:=
[]
Channel
{{
ID
:
1
,
Name
:
"chA"
,
Status
:
StatusActive
,
GroupIDs
:
[]
int64
{
10
,
20
},
}}
svc
:=
newAvailableChannelService
(
channels
,
&
stubGroupRepoForAvailable
{})
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
out
,
1
)
require
.
Empty
(
t
,
out
[
0
]
.
Groups
)
}
func
TestListAvailable_InactiveGroupIDSilentlyDropped
(
t
*
testing
.
T
)
{
// 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
channels
:=
[]
Channel
{{
ID
:
1
,
Name
:
"chA"
,
Status
:
StatusActive
,
GroupIDs
:
[]
int64
{
1
,
99
},
}}
groupRepo
:=
&
stubGroupRepoForAvailable
{
activeGroups
:
[]
Group
{{
ID
:
1
,
Name
:
"g1"
,
Platform
:
"anthropic"
}},
}
svc
:=
newAvailableChannelService
(
channels
,
groupRepo
)
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
out
,
1
)
require
.
Len
(
t
,
out
[
0
]
.
Groups
,
1
)
require
.
Equal
(
t
,
int64
(
1
),
out
[
0
]
.
Groups
[
0
]
.
ID
)
}
func
TestListAvailable_SortedByName
(
t
*
testing
.
T
)
{
channels
:=
[]
Channel
{
{
ID
:
1
,
Name
:
"beta"
},
{
ID
:
2
,
Name
:
"Alpha"
},
{
ID
:
3
,
Name
:
"charlie"
},
}
svc
:=
newAvailableChannelService
(
channels
,
&
stubGroupRepoForAvailable
{})
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
out
,
3
)
require
.
Equal
(
t
,
"Alpha"
,
out
[
0
]
.
Name
)
require
.
Equal
(
t
,
"beta"
,
out
[
1
]
.
Name
)
require
.
Equal
(
t
,
"charlie"
,
out
[
2
]
.
Name
)
}
func
TestListAvailable_ListAllErrorPropagates
(
t
*
testing
.
T
)
{
// ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。
sentinel
:=
errors
.
New
(
"list-all-boom"
)
repo
:=
&
mockChannelRepository
{
listAllFn
:
func
(
ctx
context
.
Context
)
([]
Channel
,
error
)
{
return
nil
,
sentinel
},
}
groupRepo
:=
&
stubGroupRepoForAvailable
{}
svc
:=
NewChannelService
(
repo
,
groupRepo
,
nil
,
nil
)
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
Nil
(
t
,
out
)
require
.
ErrorIs
(
t
,
err
,
sentinel
)
require
.
Contains
(
t
,
err
.
Error
(),
"list channels"
,
"wrap 前缀缺失,可能 %w 被改为 %v"
)
require
.
Equal
(
t
,
0
,
groupRepo
.
listActiveCalls
,
"ListAll 失败后不应再调用 groupRepo.ListActive"
)
}
func
TestListAvailable_ListActiveErrorPropagates
(
t
*
testing
.
T
)
{
// groupRepo.ListActive 返回错误时 ListAvailable 应直接返回包装后的错误。
sentinel
:=
errors
.
New
(
"list-active-boom"
)
svc
:=
newAvailableChannelService
(
[]
Channel
{{
ID
:
1
,
Name
:
"chA"
}},
&
stubGroupRepoForAvailable
{
listActiveErr
:
sentinel
},
)
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
Nil
(
t
,
out
)
require
.
ErrorIs
(
t
,
err
,
sentinel
)
require
.
Contains
(
t
,
err
.
Error
(),
"list active groups"
,
"wrap 前缀缺失,可能 %w 被改为 %v"
)
}
func
TestListAvailable_DefaultsEmptyBillingModelSource
(
t
*
testing
.
T
)
{
// 渠道 BillingModelSource 为空时应回填为 BillingModelSourceChannelMapped,
// 显式值应原样保留(由 service 层统一处理,避免各 handler 重复默认逻辑)。
channels
:=
[]
Channel
{
{
ID
:
1
,
Name
:
"empty"
,
BillingModelSource
:
""
},
{
ID
:
2
,
Name
:
"explicit"
,
BillingModelSource
:
BillingModelSourceUpstream
},
}
svc
:=
newAvailableChannelService
(
channels
,
&
stubGroupRepoForAvailable
{})
out
,
err
:=
svc
.
ListAvailable
(
context
.
Background
())
require
.
NoError
(
t
,
err
)
require
.
Len
(
t
,
out
,
2
)
// 按 Name 查找,避免依赖排序副作用。
byName
:=
make
(
map
[
string
]
string
,
len
(
out
))
for
_
,
ch
:=
range
out
{
byName
[
ch
.
Name
]
=
ch
.
BillingModelSource
}
require
.
Equal
(
t
,
BillingModelSourceChannelMapped
,
byName
[
"empty"
])
require
.
Equal
(
t
,
BillingModelSourceUpstream
,
byName
[
"explicit"
])
}
backend/internal/service/channel_monitor_aggregator.go
0 → 100644
View file @
ac114738
package
service
import
(
"context"
"fmt"
"log/slog"
)
// 渠道监控聚合层:把 latest + availability 拼成 admin/user 视图所需的 summary / detail。
// 所有方法都遵守"失败仅日志,返回零值"的原则,避免 N+1 查询失败拖垮列表渲染。
// BatchMonitorStatusSummary 批量聚合多个监控的 latest + 7d 可用率(admin/user list 用,消除 N+1)。
// 失败时返回空 map,错误仅日志,不影响列表渲染。
//
// 参数:
// - ids: 要聚合的 monitor ID 列表
// - primaryByID: monitor ID -> primary model(用于读 7d 可用率与 latest 状态)
// - extrasByID: monitor ID -> extra models 列表(用于读 latest 状态填充 ExtraModels)
func
(
s
*
ChannelMonitorService
)
BatchMonitorStatusSummary
(
ctx
context
.
Context
,
ids
[]
int64
,
primaryByID
map
[
int64
]
string
,
extrasByID
map
[
int64
][]
string
,
)
map
[
int64
]
MonitorStatusSummary
{
out
:=
make
(
map
[
int64
]
MonitorStatusSummary
,
len
(
ids
))
if
len
(
ids
)
==
0
{
return
out
}
latestMap
,
err
:=
s
.
repo
.
ListLatestForMonitorIDs
(
ctx
,
ids
)
if
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: batch load latest failed"
,
"error"
,
err
)
latestMap
=
map
[
int64
][]
*
ChannelMonitorLatest
{}
}
availMap
,
err
:=
s
.
repo
.
ComputeAvailabilityForMonitors
(
ctx
,
ids
,
monitorAvailability7Days
)
if
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: batch compute availability failed"
,
"error"
,
err
)
availMap
=
map
[
int64
][]
*
ChannelMonitorAvailability
{}
}
for
_
,
id
:=
range
ids
{
out
[
id
]
=
buildStatusSummary
(
indexLatestByModel
(
latestMap
[
id
]),
indexAvailabilityByModel
(
availMap
[
id
]),
primaryByID
[
id
],
extrasByID
[
id
],
)
}
return
out
}
// ListUserView 用户只读视图:列出所有 enabled 监控的概览。
// 使用批量聚合接口避免 N+1:
//
// 1 次查 monitors;
// 1 次批量 latest(含 ping_latency_ms);
// 1 次批量 7d availability;
// 1 次批量 timeline(主模型最近 N 条)。
func
(
s
*
ChannelMonitorService
)
ListUserView
(
ctx
context
.
Context
)
([]
*
UserMonitorView
,
error
)
{
monitors
,
err
:=
s
.
repo
.
ListEnabled
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list enabled monitors: %w"
,
err
)
}
if
len
(
monitors
)
==
0
{
return
[]
*
UserMonitorView
{},
nil
}
ids
,
primaryByID
,
extrasByID
:=
collectMonitorIndexes
(
monitors
)
summaries
:=
s
.
BatchMonitorStatusSummary
(
ctx
,
ids
,
primaryByID
,
extrasByID
)
latestMap
:=
s
.
batchLatest
(
ctx
,
ids
)
timelineMap
:=
s
.
batchTimeline
(
ctx
,
ids
,
primaryByID
)
views
:=
make
([]
*
UserMonitorView
,
0
,
len
(
monitors
))
for
_
,
m
:=
range
monitors
{
primaryLatest
:=
pickLatest
(
latestMap
[
m
.
ID
],
m
.
PrimaryModel
)
views
=
append
(
views
,
buildUserViewFromSummary
(
m
,
summaries
[
m
.
ID
],
primaryLatest
,
timelineMap
[
m
.
ID
]))
}
return
views
,
nil
}
// collectMonitorIndexes 把 monitors 列表按 ID 展开为聚合查询所需的三个索引结构。
func
collectMonitorIndexes
(
monitors
[]
*
ChannelMonitor
)
([]
int64
,
map
[
int64
]
string
,
map
[
int64
][]
string
)
{
ids
:=
make
([]
int64
,
0
,
len
(
monitors
))
primaryByID
:=
make
(
map
[
int64
]
string
,
len
(
monitors
))
extrasByID
:=
make
(
map
[
int64
][]
string
,
len
(
monitors
))
for
_
,
m
:=
range
monitors
{
ids
=
append
(
ids
,
m
.
ID
)
primaryByID
[
m
.
ID
]
=
m
.
PrimaryModel
extrasByID
[
m
.
ID
]
=
m
.
ExtraModels
}
return
ids
,
primaryByID
,
extrasByID
}
// batchLatest 批量取 latest per model,失败仅日志(与现有 BatchMonitorStatusSummary 一致,不阻断列表渲染)。
func
(
s
*
ChannelMonitorService
)
batchLatest
(
ctx
context
.
Context
,
ids
[]
int64
)
map
[
int64
][]
*
ChannelMonitorLatest
{
latestMap
,
err
:=
s
.
repo
.
ListLatestForMonitorIDs
(
ctx
,
ids
)
if
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: user view batch latest failed"
,
"error"
,
err
)
return
map
[
int64
][]
*
ChannelMonitorLatest
{}
}
return
latestMap
}
// batchTimeline 批量取每个 monitor 主模型最近 monitorTimelineMaxPoints 条历史。
func
(
s
*
ChannelMonitorService
)
batchTimeline
(
ctx
context
.
Context
,
ids
[]
int64
,
primaryByID
map
[
int64
]
string
,
)
map
[
int64
][]
*
ChannelMonitorHistoryEntry
{
timelineMap
,
err
:=
s
.
repo
.
ListRecentHistoryForMonitors
(
ctx
,
ids
,
primaryByID
,
monitorTimelineMaxPoints
)
if
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: user view batch timeline failed"
,
"error"
,
err
)
return
map
[
int64
][]
*
ChannelMonitorHistoryEntry
{}
}
return
timelineMap
}
// pickLatest 从 latest 切片中挑出指定 model 对应项,未命中返回 nil。
func
pickLatest
(
rows
[]
*
ChannelMonitorLatest
,
model
string
)
*
ChannelMonitorLatest
{
if
model
==
""
{
return
nil
}
for
_
,
r
:=
range
rows
{
if
r
.
Model
==
model
{
return
r
}
}
return
nil
}
// GetUserDetail 用户只读视图:单个监控详情(每个模型 7d/15d/30d 可用率与平均延迟)。
// 不暴露 api_key。
func
(
s
*
ChannelMonitorService
)
GetUserDetail
(
ctx
context
.
Context
,
id
int64
)
(
*
UserMonitorDetail
,
error
)
{
m
,
err
:=
s
.
repo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
!
m
.
Enabled
{
return
nil
,
ErrChannelMonitorNotFound
}
latest
,
err
:=
s
.
repo
.
ListLatestPerModel
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list latest per model: %w"
,
err
)
}
availMap
,
err
:=
s
.
collectAvailabilityWindows
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
models
:=
mergeModelDetails
(
m
,
latest
,
availMap
)
return
&
UserMonitorDetail
{
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Provider
:
m
.
Provider
,
GroupName
:
m
.
GroupName
,
Models
:
models
,
},
nil
}
// collectAvailabilityWindows 一次性查询 7/15/30 天三个窗口,按模型组织。
func
(
s
*
ChannelMonitorService
)
collectAvailabilityWindows
(
ctx
context
.
Context
,
monitorID
int64
)
(
map
[
int
]
map
[
string
]
*
ChannelMonitorAvailability
,
error
)
{
out
:=
make
(
map
[
int
]
map
[
string
]
*
ChannelMonitorAvailability
,
3
)
windows
:=
[]
int
{
monitorAvailability7Days
,
monitorAvailability15Days
,
monitorAvailability30Days
}
for
_
,
w
:=
range
windows
{
rows
,
err
:=
s
.
repo
.
ComputeAvailability
(
ctx
,
monitorID
,
w
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"compute availability %dd: %w"
,
w
,
err
)
}
out
[
w
]
=
indexAvailabilityByModel
(
rows
)
}
return
out
,
nil
}
// ---------- 纯函数 helper(无 IO,可在 batch / 单 monitor / detail 路径复用)----------
// indexLatestByModel 把 latest 切片按 model 索引(小工具,避免在 hot path 重复写)。
func
indexLatestByModel
(
rows
[]
*
ChannelMonitorLatest
)
map
[
string
]
*
ChannelMonitorLatest
{
m
:=
make
(
map
[
string
]
*
ChannelMonitorLatest
,
len
(
rows
))
for
_
,
r
:=
range
rows
{
m
[
r
.
Model
]
=
r
}
return
m
}
// indexAvailabilityByModel 把 availability 切片按 model 索引。
func
indexAvailabilityByModel
(
rows
[]
*
ChannelMonitorAvailability
)
map
[
string
]
*
ChannelMonitorAvailability
{
m
:=
make
(
map
[
string
]
*
ChannelMonitorAvailability
,
len
(
rows
))
for
_
,
r
:=
range
rows
{
m
[
r
.
Model
]
=
r
}
return
m
}
// buildStatusSummary 由 latest + availability 字典构造 MonitorStatusSummary。
// 不做任何 IO,纯组装,便于在 batch 与单 monitor 路径复用。
func
buildStatusSummary
(
latestByModel
map
[
string
]
*
ChannelMonitorLatest
,
availByModel
map
[
string
]
*
ChannelMonitorAvailability
,
primary
string
,
extras
[]
string
,
)
MonitorStatusSummary
{
summary
:=
MonitorStatusSummary
{
ExtraModels
:
make
([]
ExtraModelStatus
,
0
,
len
(
extras
))}
if
primary
!=
""
{
if
l
,
ok
:=
latestByModel
[
primary
];
ok
{
summary
.
PrimaryStatus
=
l
.
Status
summary
.
PrimaryLatencyMs
=
l
.
LatencyMs
}
if
a
,
ok
:=
availByModel
[
primary
];
ok
{
summary
.
Availability7d
=
a
.
AvailabilityPct
}
}
for
_
,
model
:=
range
extras
{
entry
:=
ExtraModelStatus
{
Model
:
model
}
if
l
,
ok
:=
latestByModel
[
model
];
ok
{
entry
.
Status
=
l
.
Status
entry
.
LatencyMs
=
l
.
LatencyMs
}
summary
.
ExtraModels
=
append
(
summary
.
ExtraModels
,
entry
)
}
return
summary
}
// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary + 主模型 latest + timeline 装填 UserMonitorView(无 IO)。
// primaryLatest 可能为 nil(该监控尚无历史);timelineEntries 可能为空。
func
buildUserViewFromSummary
(
m
*
ChannelMonitor
,
summary
MonitorStatusSummary
,
primaryLatest
*
ChannelMonitorLatest
,
timelineEntries
[]
*
ChannelMonitorHistoryEntry
,
)
*
UserMonitorView
{
view
:=
&
UserMonitorView
{
ID
:
m
.
ID
,
Name
:
m
.
Name
,
Provider
:
m
.
Provider
,
GroupName
:
m
.
GroupName
,
PrimaryModel
:
m
.
PrimaryModel
,
PrimaryStatus
:
summary
.
PrimaryStatus
,
PrimaryLatencyMs
:
summary
.
PrimaryLatencyMs
,
Availability7d
:
summary
.
Availability7d
,
ExtraModels
:
summary
.
ExtraModels
,
Timeline
:
buildTimelinePoints
(
timelineEntries
),
}
if
primaryLatest
!=
nil
{
view
.
PrimaryPingLatencyMs
=
primaryLatest
.
PingLatencyMs
}
return
view
}
// buildTimelinePoints 把 history entry 裁剪为 timeline 点(去除 message/ID/Model,减小响应体)。
func
buildTimelinePoints
(
entries
[]
*
ChannelMonitorHistoryEntry
)
[]
UserMonitorTimelinePoint
{
out
:=
make
([]
UserMonitorTimelinePoint
,
0
,
len
(
entries
))
for
_
,
e
:=
range
entries
{
out
=
append
(
out
,
UserMonitorTimelinePoint
{
Status
:
e
.
Status
,
LatencyMs
:
e
.
LatencyMs
,
PingLatencyMs
:
e
.
PingLatencyMs
,
CheckedAt
:
e
.
CheckedAt
,
})
}
return
out
}
// mergeModelDetails 合并 latest + availability 三个窗口为 ModelDetail 列表。
// 复用 indexLatestByModel,避免在多处重复写 build map 逻辑。
func
mergeModelDetails
(
m
*
ChannelMonitor
,
latest
[]
*
ChannelMonitorLatest
,
availMap
map
[
int
]
map
[
string
]
*
ChannelMonitorAvailability
,
)
[]
ModelDetail
{
all
:=
append
([]
string
{
m
.
PrimaryModel
},
m
.
ExtraModels
...
)
latestByModel
:=
indexLatestByModel
(
latest
)
out
:=
make
([]
ModelDetail
,
0
,
len
(
all
))
for
_
,
model
:=
range
all
{
d
:=
ModelDetail
{
Model
:
model
}
if
l
,
ok
:=
latestByModel
[
model
];
ok
{
d
.
LatestStatus
=
l
.
Status
d
.
LatestLatencyMs
=
l
.
LatencyMs
}
if
a
,
ok
:=
availMap
[
monitorAvailability7Days
][
model
];
ok
{
d
.
Availability7d
=
a
.
AvailabilityPct
d
.
AvgLatency7dMs
=
a
.
AvgLatencyMs
}
if
a
,
ok
:=
availMap
[
monitorAvailability15Days
][
model
];
ok
{
d
.
Availability15d
=
a
.
AvailabilityPct
}
if
a
,
ok
:=
availMap
[
monitorAvailability30Days
][
model
];
ok
{
d
.
Availability30d
=
a
.
AvailabilityPct
}
out
=
append
(
out
,
d
)
}
return
out
}
backend/internal/service/channel_monitor_challenge.go
0 → 100644
View file @
ac114738
package
service
import
(
"fmt"
"math/rand/v2"
"regexp"
"strconv"
)
// monitorChallengePromptTemplate is a 1:1 copy of the few-shot prompt template
// used by BingZi-233/check-cx. Its three placeholders are, in order: the left
// operand (%d), the operator (%s) and the right operand (%d).
const monitorChallengePromptTemplate = `Calculate and respond with ONLY the number, nothing else.
Q: 3 + 5 = ?
A: 8
Q: 12 - 7 = ?
A: 5
Q: %d %s %d = ?
A:`
// monitorChallengeNumberRegex extracts every integer (with an optional leading
// minus sign) from a response text.
var monitorChallengeNumberRegex = regexp.MustCompile(`-?\d+`)
// monitorChallenge holds one generated challenge: the prompt sent to the model
// and the expected answer (a decimal integer rendered as a string).
type monitorChallenge struct {
	Prompt   string
	Expected string
}
// generateChallenge 生成一次随机算术 challenge:
// - 随机两个 [monitorChallengeMin, monitorChallengeMax] 整数
// - 50% 加 / 50% 减;减法用 max - min 保证非负
// - 渲染 few-shot 模板
//
// 不强求加密随机:math/rand/v2 足够分散,避免 crypto/rand 的开销。
func
generateChallenge
()
monitorChallenge
{
a
:=
randIntInRange
(
monitorChallengeMin
,
monitorChallengeMax
)
b
:=
randIntInRange
(
monitorChallengeMin
,
monitorChallengeMax
)
if
rand
.
IntN
(
2
)
==
0
{
//nolint:gosec // 仅用于生成测试问题,无安全影响
// 加法
return
monitorChallenge
{
Prompt
:
fmt
.
Sprintf
(
monitorChallengePromptTemplate
,
a
,
"+"
,
b
),
Expected
:
strconv
.
Itoa
(
a
+
b
),
}
}
// 减法,保证非负
hi
,
lo
:=
a
,
b
if
lo
>
hi
{
hi
,
lo
=
lo
,
hi
}
return
monitorChallenge
{
Prompt
:
fmt
.
Sprintf
(
monitorChallengePromptTemplate
,
hi
,
"-"
,
lo
),
Expected
:
strconv
.
Itoa
(
hi
-
lo
),
}
}
// randIntInRange 返回 [min, max] 闭区间的随机整数。
func
randIntInRange
(
minVal
,
maxVal
int
)
int
{
if
maxVal
<=
minVal
{
return
minVal
}
return
minVal
+
rand
.
IntN
(
maxVal
-
minVal
+
1
)
//nolint:gosec
}
// validateChallenge 在响应文本中查找 expected 整数答案,返回是否通过校验。
func
validateChallenge
(
responseText
,
expected
string
)
bool
{
if
responseText
==
""
||
expected
==
""
{
return
false
}
matches
:=
monitorChallengeNumberRegex
.
FindAllString
(
responseText
,
-
1
)
for
_
,
m
:=
range
matches
{
if
m
==
expected
{
return
true
}
}
return
false
}
backend/internal/service/channel_monitor_checker.go
0 → 100644
View file @
ac114738
package
service
import
(
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/tidwall/gjson"
)
// monitorHTTPClient is the shared http.Client used for model checks, so the
// transport is not rebuilt on every check. Per the original author's note, the
// custom transport re-validates the target IP at dial time to stop DNS
// rebinding from bypassing validateEndpoint.
var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout)

// monitorPingHTTPClient is used for the HEAD ping against the endpoint origin;
// it uses the shorter monitorPingTimeout.
var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout)
// newSSRFSafeHTTPClient 返回一个使用 safeDialContext 的 http.Client。
// 仅供监控模块对外发起请求使用——所有目标都应是公网 endpoint。
func
newSSRFSafeHTTPClient
(
timeout
time
.
Duration
)
*
http
.
Client
{
tr
:=
&
http
.
Transport
{
DialContext
:
safeDialContext
,
ForceAttemptHTTP2
:
true
,
MaxIdleConns
:
16
,
IdleConnTimeout
:
monitorIdleConnTimeout
,
TLSHandshakeTimeout
:
monitorTLSHandshakeTimeout
,
ResponseHeaderTimeout
:
monitorResponseHeaderTimeout
,
}
return
&
http
.
Client
{
Timeout
:
timeout
,
Transport
:
tr
}
}
// CheckOptions carries the custom inputs of a single check.
// Every field is optional; the zero value means "use the default behavior".
type CheckOptions struct {
	// ExtraHeaders are user-defined HTTP headers merged onto the adapter's
	// default headers (user values take precedence).
	ExtraHeaders map[string]string
	// BodyOverrideMode is one of: off | merge | replace.
	BodyOverrideMode string
	// BodyOverride is shallow-merged into the default body in merge mode (keys
	// on the deny list are silently dropped) and used as the complete body in
	// replace mode.
	BodyOverride map[string]any
}
// runCheckForModel 对单个 (provider, model) 做一次完整检测。
// 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。
//
// opts 承载模板 / 监控快照带来的自定义配置。nil 等同于 "off + 无 extra headers"。
func
runCheckForModel
(
ctx
context
.
Context
,
provider
,
endpoint
,
apiKey
,
model
string
,
opts
*
CheckOptions
)
*
CheckResult
{
res
:=
&
CheckResult
{
Model
:
model
,
Status
:
MonitorStatusError
,
CheckedAt
:
time
.
Now
(),
}
challenge
:=
generateChallenge
()
mode
:=
bodyOverrideMode
(
opts
)
start
:=
time
.
Now
()
respText
,
rawBody
,
statusCode
,
err
:=
callProvider
(
ctx
,
provider
,
endpoint
,
apiKey
,
model
,
challenge
.
Prompt
,
opts
)
latency
:=
time
.
Since
(
start
)
latencyMs
:=
int
(
latency
/
time
.
Millisecond
)
res
.
LatencyMs
=
&
latencyMs
if
err
!=
nil
{
res
.
Status
=
MonitorStatusError
res
.
Message
=
truncateMessage
(
sanitizeErrorMessage
(
err
.
Error
()))
return
res
}
if
statusCode
<
200
||
statusCode
>=
300
{
// 错误路径:用 rawBody 而非 respText(gjson textPath 抽取在错误响应里通常为空,
// 会丢掉真正的上游错误信息,例如 `{"error":{"message":"No available accounts ..."}}`)。
res
.
Status
=
MonitorStatusError
bodySnippet
:=
truncateForErrorBody
(
rawBody
)
res
.
Message
=
truncateMessage
(
sanitizeErrorMessage
(
fmt
.
Sprintf
(
"upstream HTTP %d: %s"
,
statusCode
,
bodySnippet
)))
return
res
}
// Replace 模式:跳过 challenge 校验(用户 body 是静态的,challenge 没法嵌入)。
// 改用「HTTP 2xx + 响应文本(adapter.textPath 抽取)非空」作为 operational 判定。
// 响应文本为空则降级为 failed(视为上游回了 200 但没实际内容)。
if
mode
==
MonitorBodyOverrideModeReplace
{
if
strings
.
TrimSpace
(
respText
)
==
""
{
res
.
Status
=
MonitorStatusFailed
res
.
Message
=
truncateMessage
(
"replace-mode: upstream returned 2xx with empty text"
)
return
res
}
return
finalizeOperationalOrDegraded
(
res
,
latency
,
latencyMs
)
}
if
!
validateChallenge
(
respText
,
challenge
.
Expected
)
{
res
.
Status
=
MonitorStatusFailed
res
.
Message
=
truncateMessage
(
sanitizeErrorMessage
(
fmt
.
Sprintf
(
"challenge mismatch (expected %s, got %q)"
,
challenge
.
Expected
,
respText
)))
return
res
}
return
finalizeOperationalOrDegraded
(
res
,
latency
,
latencyMs
)
}
// finalizeOperationalOrDegraded 负责走到最后一步的 operational/degraded 判定。
// 拆出来是为了让 runCheckForModel 不超过 30 行。
func
finalizeOperationalOrDegraded
(
res
*
CheckResult
,
latency
time
.
Duration
,
latencyMs
int
)
*
CheckResult
{
if
latency
>=
monitorDegradedThreshold
{
res
.
Status
=
MonitorStatusDegraded
res
.
Message
=
truncateMessage
(
fmt
.
Sprintf
(
"slow response: %dms"
,
latencyMs
))
return
res
}
res
.
Status
=
MonitorStatusOperational
return
res
}
// bodyOverrideMode 归一取 opts.BodyOverrideMode,nil opts / 空串都视为 off。
func
bodyOverrideMode
(
opts
*
CheckOptions
)
string
{
if
opts
==
nil
||
opts
.
BodyOverrideMode
==
""
{
return
MonitorBodyOverrideModeOff
}
return
opts
.
BodyOverrideMode
}
// pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。
// 失败时返回 nil(不影响主状态判定)。
func
pingEndpointOrigin
(
ctx
context
.
Context
,
endpoint
string
)
*
int
{
origin
,
err
:=
extractOrigin
(
endpoint
)
if
err
!=
nil
||
origin
==
""
{
return
nil
}
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodHead
,
origin
,
nil
)
if
err
!=
nil
{
return
nil
}
start
:=
time
.
Now
()
resp
,
err
:=
monitorPingHTTPClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
_
,
_
=
io
.
Copy
(
io
.
Discard
,
io
.
LimitReader
(
resp
.
Body
,
monitorPingDiscardMaxBytes
))
ms
:=
int
(
time
.
Since
(
start
)
/
time
.
Millisecond
)
return
&
ms
}
// providerAdapter describes the four things a provider needs for a challenge
// check:
//   - building the request path (may embed the model name)
//   - serializing the request body
//   - building the authentication headers
//   - extracting the answer text from the response JSON via a gjson path
//
// Adding a provider only requires a new entry in providerAdapters; callProvider
// and validateProvider do not need to change.
type providerAdapter struct {
	buildPath    func(model string) string
	buildBody    func(model, prompt string) ([]byte, error)
	buildHeaders func(apiKey string) map[string]string
	textPath     string // gjson path used to extract the response text
}
// providerAdapters lists every supported provider. The keys are the
// MonitorProvider* string constants.
//
//nolint:gochecknoglobals // the adapter table is read-only static data and never changes after initialization.
var providerAdapters = map[string]providerAdapter{
	// OpenAI Chat Completions: model in the body, Bearer token auth.
	MonitorProviderOpenAI: {
		buildPath: func(string) string { return providerOpenAIPath },
		buildBody: func(model, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"model":      model,
				"messages":   []map[string]string{{"role": "user", "content": prompt}},
				"max_tokens": monitorChallengeMaxTokens,
				"stream":     false,
			})
		},
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{"Authorization": "Bearer " + apiKey}
		},
		textPath: "choices.0.message.content",
	},
	// Anthropic Messages: model in the body, x-api-key + version header auth.
	MonitorProviderAnthropic: {
		buildPath: func(string) string { return providerAnthropicPath },
		buildBody: func(model, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"model":      model,
				"messages":   []map[string]string{{"role": "user", "content": prompt}},
				"max_tokens": monitorChallengeMaxTokens,
			})
		},
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{
				"x-api-key":         apiKey,
				"anthropic-version": monitorAnthropicAPIVersion,
			}
		},
		textPath: "content.0.text",
	},
	MonitorProviderGemini: {
		// Gemini puts the model name in the URL path: /v1beta/models/{model}:generateContent
		// NOTE(review): the model name is interpolated without URL escaping — confirm
		// that model names are validated upstream of this call.
		buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) },
		buildBody: func(_, prompt string) ([]byte, error) {
			return json.Marshal(map[string]any{
				"contents": []map[string]any{
					{"parts": []map[string]any{{"text": prompt}}},
				},
				"generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens},
			})
		},
		// Use the x-goog-api-key header instead of the ?key= query parameter so
		// that *url.Error cannot echo the key back into error logs.
		buildHeaders: func(apiKey string) map[string]string {
			return map[string]string{"x-goog-api-key": apiKey}
		},
		textPath: "candidates.0.content.parts.0.text",
	},
}
// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
func
isSupportedProvider
(
p
string
)
bool
{
_
,
ok
:=
providerAdapters
[
p
]
return
ok
}
// callProvider 通过 providerAdapters 分发到具体实现。
// opts 承载用户的自定义 headers / body 覆盖(可为 nil)。
//
// 返回值:
// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串
// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包
// - status: HTTP 状态码
// - err: 网络 / 序列化错误
func
callProvider
(
ctx
context
.
Context
,
provider
,
endpoint
,
apiKey
,
model
,
prompt
string
,
opts
*
CheckOptions
)
(
extractedText
,
rawBody
string
,
status
int
,
err
error
)
{
adapter
,
ok
:=
providerAdapters
[
provider
]
if
!
ok
{
return
""
,
""
,
0
,
fmt
.
Errorf
(
"unsupported provider %q"
,
provider
)
}
body
,
err
:=
buildRequestBody
(
adapter
,
provider
,
model
,
prompt
,
opts
)
if
err
!=
nil
{
return
""
,
""
,
0
,
err
}
headers
:=
mergeHeaders
(
adapter
.
buildHeaders
(
apiKey
),
opts
)
full
:=
joinURL
(
endpoint
,
adapter
.
buildPath
(
model
))
respBytes
,
status
,
err
:=
postRawJSON
(
ctx
,
full
,
body
,
headers
)
if
err
!=
nil
{
return
""
,
""
,
status
,
err
}
return
gjson
.
GetBytes
(
respBytes
,
adapter
.
textPath
)
.
String
(),
string
(
respBytes
),
status
,
nil
}
// mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。
// 用户值覆盖默认;命中黑名单(hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。
func
mergeHeaders
(
base
map
[
string
]
string
,
opts
*
CheckOptions
)
map
[
string
]
string
{
if
opts
==
nil
||
len
(
opts
.
ExtraHeaders
)
==
0
{
return
base
}
out
:=
make
(
map
[
string
]
string
,
len
(
base
)
+
len
(
opts
.
ExtraHeaders
))
for
k
,
v
:=
range
base
{
out
[
k
]
=
v
}
for
k
,
v
:=
range
opts
.
ExtraHeaders
{
if
IsForbiddenHeaderName
(
k
)
{
continue
}
out
[
k
]
=
v
}
return
out
}
// buildRequestBody 根据 body_override_mode 构造请求 body。
//
// - off: adapter 默认 body
// - merge: adapter 默认 body 与 BodyOverride 浅合并;BodyOverride 中命中
// bodyMergeKeyDenyList[provider] 的 key 会被静默丢弃,避免破坏 challenge / model 路由
// - replace: 直接 marshal BodyOverride 作为完整 body
//
// 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。
func
buildRequestBody
(
adapter
providerAdapter
,
provider
,
model
,
prompt
string
,
opts
*
CheckOptions
)
([]
byte
,
error
)
{
mode
:=
bodyOverrideMode
(
opts
)
if
mode
==
MonitorBodyOverrideModeReplace
{
if
opts
==
nil
||
len
(
opts
.
BodyOverride
)
==
0
{
return
nil
,
fmt
.
Errorf
(
"replace mode: body_override is empty"
)
}
body
,
err
:=
json
.
Marshal
(
opts
.
BodyOverride
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal body_override (replace): %w"
,
err
)
}
return
body
,
nil
}
defaultBody
,
err
:=
adapter
.
buildBody
(
model
,
prompt
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal default body: %w"
,
err
)
}
if
mode
!=
MonitorBodyOverrideModeMerge
||
opts
==
nil
||
len
(
opts
.
BodyOverride
)
==
0
{
return
defaultBody
,
nil
}
var
defaultMap
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
defaultBody
,
&
defaultMap
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"unmarshal default body for merge: %w"
,
err
)
}
deny
:=
bodyMergeKeyDenyList
[
provider
]
for
k
,
v
:=
range
opts
.
BodyOverride
{
if
deny
[
k
]
{
continue
}
defaultMap
[
k
]
=
v
}
merged
,
err
:=
json
.
Marshal
(
defaultMap
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"marshal merged body: %w"
,
err
)
}
return
merged
,
nil
}
// bodyMergeKeyDenyList lists, per provider, the body keys a user may not
// override in merge mode. The idea is borrowed from check-cx's
// EXCLUDED_METADATA_KEYS: protect the challenge / model routing from accidental
// user changes. Users who need to change these fields must use replace mode
// (which is known to skip challenge validation).
//
//nolint:gochecknoglobals // static lookup table, never changes after initialization.
var bodyMergeKeyDenyList = map[string]map[string]bool{
	MonitorProviderOpenAI:    {"model": true, "messages": true, "stream": true},
	MonitorProviderAnthropic: {"model": true, "messages": true},
	MonitorProviderGemini:    {"contents": true},
}
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
// adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。
func
postRawJSON
(
ctx
context
.
Context
,
fullURL
string
,
payload
[]
byte
,
headers
map
[
string
]
string
)
([]
byte
,
int
,
error
)
{
req
,
err
:=
http
.
NewRequestWithContext
(
ctx
,
http
.
MethodPost
,
fullURL
,
bytes
.
NewReader
(
payload
))
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"build request: %w"
,
err
)
}
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
Header
.
Set
(
"Accept"
,
"application/json"
)
for
k
,
v
:=
range
headers
{
req
.
Header
.
Set
(
k
,
v
)
}
resp
,
err
:=
monitorHTTPClient
.
Do
(
req
)
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"do request: %w"
,
err
)
}
defer
func
()
{
_
=
resp
.
Body
.
Close
()
}()
respBody
,
err
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
monitorResponseMaxBytes
))
if
err
!=
nil
{
return
nil
,
resp
.
StatusCode
,
fmt
.
Errorf
(
"read body: %w"
,
err
)
}
return
respBody
,
resp
.
StatusCode
,
nil
}
// joinURL concatenates a base origin and a path into a full URL. Trailing
// slashes on base are removed and a leading slash is added to path when it is
// missing, so exactly one slash separates the two parts.
func joinURL(base, path string) string {
	origin := strings.TrimRight(base, "/")
	if strings.HasPrefix(path, "/") {
		return origin + path
	}
	return origin + "/" + path
}
// extractOrigin returns the scheme://host[:port] part of an endpoint URL.
// It fails when the URL cannot be parsed or when it lacks a scheme or a host.
func extractOrigin(endpoint string) (string, error) {
	parsed, err := url.Parse(endpoint)
	switch {
	case err != nil:
		return "", err
	case parsed.Scheme == "" || parsed.Host == "":
		return "", errors.New("endpoint missing scheme or host")
	}
	return parsed.Scheme + "://" + parsed.Host, nil
}
// monitorSensitiveQueryParamRegex matches URL query parameters that may leak
// credentials: key / api_key / api-key / access_token / token / authorization /
// x-api-key. Matching is case-insensitive and covers the `?name=value` and
// `&name=value` forms; the value runs up to the next `&`, whitespace, quote or
// the end of the string. Capture group 1 is the `?name=` / `&name=` prefix so
// callers can keep it while replacing the value.
var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`)
// monitorAPIKeyPatterns matches API key literals of common providers and maps
// each to its redaction placeholder.
//
// Order matters: sk-ant- must come before the generic sk- pattern, otherwise
// the generic pattern would consume Anthropic keys first.
var monitorAPIKeyPatterns = []struct {
	pattern *regexp.Regexp
	replace string
}{
	// Anthropic (prefixed, must match first): sk-ant-xxxxxxx
	{regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"},
	// Generic sk- keys (OpenAI and compatible): sk-xxxxxxx / sk-proj-xxx_xxx.
	// The character class includes '_' like the sk-ant- pattern above; without
	// it a key containing an underscore is redacted only up to the underscore,
	// or not at all when the underscore appears within the first 20 characters.
	{regexp.MustCompile(`sk-[A-Za-z0-9_-]{20,}`), "sk-***REDACTED***"},
	// Gemini / Google API key: fixed prefix + 35 characters
	{regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"},
	// Three-segment JWT (commonly seen after "Bearer"): eyJxxx.eyJxxx.signature
	{regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"},
}
// sanitizeErrorMessage 擦除错误/响应文本中可能泄露的 API key。
// 处理两类来源:
// 1. URL query 中的 ?key= / ?api_key= 等(Go *url.Error 会回填完整 URL)
// 2. 上游 HTTP body 文本里直接出现的 sk-* / AIza* / JWT 等密钥碎片
//
// 注意:与 gemini_messages_compat_service.go 的 sanitizeUpstreamErrorMessage 关注点类似但参数集更广,
// 监控模块独立维护,避免互相耦合。
func
sanitizeErrorMessage
(
msg
string
)
string
{
if
msg
==
""
{
return
msg
}
msg
=
monitorSensitiveQueryParamRegex
.
ReplaceAllString
(
msg
,
`${1}REDACTED`
)
for
_
,
p
:=
range
monitorAPIKeyPatterns
{
msg
=
p
.
pattern
.
ReplaceAllString
(
msg
,
p
.
replace
)
}
return
msg
}
// truncateMessage 把消息按 monitorMessageMaxBytes 截断,避免 DB 列溢出与日志过长。
func
truncateMessage
(
msg
string
)
string
{
if
len
(
msg
)
<=
monitorMessageMaxBytes
{
return
msg
}
const
ellipsis
=
"...(truncated)"
cutoff
:=
monitorMessageMaxBytes
-
len
(
ellipsis
)
if
cutoff
<
0
{
cutoff
=
0
}
return
msg
[
:
cutoff
]
+
ellipsis
}
// truncateForErrorBody 把上游错误响应 body 压到 monitorErrorBodySnippetMaxBytes 以内,
// 并顺手把连续空白折成一个空格:上游 HTML 错误页常含大量缩进/换行,保留会浪费预算。
// 被 truncateMessage 做最终总截断兜底,所以这里只负责 body 自身的精简。
func
truncateForErrorBody
(
body
string
)
string
{
body
=
strings
.
Join
(
strings
.
Fields
(
body
),
" "
)
if
len
(
body
)
<=
monitorErrorBodySnippetMaxBytes
{
return
body
}
const
ellipsis
=
"...(body truncated)"
cutoff
:=
monitorErrorBodySnippetMaxBytes
-
len
(
ellipsis
)
if
cutoff
<
0
{
cutoff
=
0
}
return
body
[
:
cutoff
]
+
ellipsis
}
backend/internal/service/channel_monitor_checker_body_test.go
0 → 100644
View file @
ac114738
//go:build unit
package
service
import
(
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// swapMonitorHTTPClient 临时替换 monitorHTTPClient 为不带 SSRF 校验的普通 client,
// 让 httptest (127.0.0.1) 能连通。测试结束后恢复。
func
swapMonitorHTTPClient
(
t
*
testing
.
T
)
{
t
.
Helper
()
orig
:=
monitorHTTPClient
monitorHTTPClient
=
&
http
.
Client
{
Timeout
:
5
*
time
.
Second
}
t
.
Cleanup
(
func
()
{
monitorHTTPClient
=
orig
})
}
// captureHandler records the body and headers of the most recent request so
// tests can assert on them, and answers with an Anthropic-shaped response.
type captureHandler struct {
	lastBody    map[string]any
	lastHeaders http.Header
	respondText string // written into the Anthropic content[0].text field (used for validation)
	status      int
}
func
(
h
*
captureHandler
)
ServeHTTP
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
h
.
lastHeaders
=
r
.
Header
.
Clone
()
defer
func
()
{
_
=
r
.
Body
.
Close
()
}()
var
parsed
map
[
string
]
any
_
=
json
.
NewDecoder
(
r
.
Body
)
.
Decode
(
&
parsed
)
h
.
lastBody
=
parsed
if
h
.
status
==
0
{
h
.
status
=
200
}
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
w
.
WriteHeader
(
h
.
status
)
// 构造 Anthropic 格式的响应:content[0].text = h.respondText
_
=
json
.
NewEncoder
(
w
)
.
Encode
(
map
[
string
]
any
{
"content"
:
[]
map
[
string
]
any
{
{
"type"
:
"text"
,
"text"
:
h
.
respondText
},
},
})
}
func
setupFakeAnthropic
(
t
*
testing
.
T
,
handler
*
captureHandler
)
string
{
t
.
Helper
()
swapMonitorHTTPClient
(
t
)
srv
:=
httptest
.
NewServer
(
handler
)
t
.
Cleanup
(
srv
.
Close
)
return
srv
.
URL
}
func
TestRunCheckForModel_OffMode_PreservesDefaultBody
(
t
*
testing
.
T
)
{
h
:=
&
captureHandler
{
respondText
:
"the answer is 42"
}
endpoint
:=
setupFakeAnthropic
(
t
,
h
)
// 跑一次 off 模式(opts=nil),确认默认 body 行为未变
_
=
runCheckForModel
(
context
.
Background
(),
MonitorProviderAnthropic
,
endpoint
,
"sk-fake"
,
"claude-x"
,
nil
)
if
h
.
lastBody
[
"model"
]
!=
"claude-x"
{
t
.
Errorf
(
"default body should contain model=claude-x, got %v"
,
h
.
lastBody
[
"model"
])
}
if
_
,
ok
:=
h
.
lastBody
[
"messages"
];
!
ok
{
t
.
Error
(
"default body should contain messages"
)
}
if
h
.
lastHeaders
.
Get
(
"x-api-key"
)
!=
"sk-fake"
{
t
.
Errorf
(
"expected adapter's x-api-key header, got %q"
,
h
.
lastHeaders
.
Get
(
"x-api-key"
))
}
}
func
TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects
(
t
*
testing
.
T
)
{
h
:=
&
captureHandler
{
respondText
:
"the answer is 42"
}
endpoint
:=
setupFakeAnthropic
(
t
,
h
)
opts
:=
&
CheckOptions
{
BodyOverrideMode
:
MonitorBodyOverrideModeMerge
,
BodyOverride
:
map
[
string
]
any
{
"system"
:
"You are Claude Code..."
,
"max_tokens"
:
float64
(
999
),
// 应该覆盖默认 50
"model"
:
"hacked-model"
,
// 应该被黑名单挡住,保留原 model
"messages"
:
[]
any
{},
// 同上,被挡
},
ExtraHeaders
:
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/1.0"
,
"Content-Length"
:
"999"
,
// 黑名单
"x-custom"
:
"ok"
,
},
}
_
=
runCheckForModel
(
context
.
Background
(),
MonitorProviderAnthropic
,
endpoint
,
"sk-fake"
,
"claude-x"
,
opts
)
if
h
.
lastBody
[
"system"
]
!=
"You are Claude Code..."
{
t
.
Errorf
(
"merge mode should inject system, got %v"
,
h
.
lastBody
[
"system"
])
}
// max_tokens 覆盖生效
if
mt
,
ok
:=
h
.
lastBody
[
"max_tokens"
]
.
(
float64
);
!
ok
||
mt
!=
999
{
t
.
Errorf
(
"merge mode should override max_tokens to 999, got %v"
,
h
.
lastBody
[
"max_tokens"
])
}
// model 在黑名单 — 应该保留默认值
if
h
.
lastBody
[
"model"
]
!=
"claude-x"
{
t
.
Errorf
(
"model should be protected by deny list, got %v"
,
h
.
lastBody
[
"model"
])
}
// messages 在黑名单 — 应该保留默认值(非空)
msgs
,
_
:=
h
.
lastBody
[
"messages"
]
.
([]
any
)
if
len
(
msgs
)
==
0
{
t
.
Error
(
"messages should be protected by deny list (kept default, non-empty)"
)
}
// header 合并
if
h
.
lastHeaders
.
Get
(
"User-Agent"
)
!=
"claude-cli/1.0"
{
t
.
Errorf
(
"extra User-Agent should override, got %q"
,
h
.
lastHeaders
.
Get
(
"User-Agent"
))
}
if
h
.
lastHeaders
.
Get
(
"x-custom"
)
!=
"ok"
{
t
.
Errorf
(
"extra custom header should be present, got %q"
,
h
.
lastHeaders
.
Get
(
"x-custom"
))
}
// Content-Length 黑名单:会被 net/http 自动重算,但不应由用户的 "999" 决定。
// 我们无法直接断言丢弃(http.Client 总会填上),只断言请求成功即可。
}
func
TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped
(
t
*
testing
.
T
)
{
// replace 模式下我们的 body 完全自定义,challenge 数学题不会出现在请求里,
// 上游也不会回正确答案 — 但只要 2xx + 响应文本非空,就算 operational
h
:=
&
captureHandler
{
respondText
:
"any non-empty text"
}
endpoint
:=
setupFakeAnthropic
(
t
,
h
)
userBody
:=
map
[
string
]
any
{
"model"
:
"user-forced-model"
,
"messages"
:
[]
any
{
map
[
string
]
any
{
"role"
:
"user"
,
"content"
:
"hi"
}},
"max_tokens"
:
float64
(
10
),
"system"
:
"You are someone else"
,
}
opts
:=
&
CheckOptions
{
BodyOverrideMode
:
MonitorBodyOverrideModeReplace
,
BodyOverride
:
userBody
,
}
res
:=
runCheckForModel
(
context
.
Background
(),
MonitorProviderAnthropic
,
endpoint
,
"sk-fake"
,
"claude-x"
,
opts
)
// 请求 body = 用户提供的原样
if
h
.
lastBody
[
"model"
]
!=
"user-forced-model"
{
t
.
Errorf
(
"replace mode should use user's model, got %v"
,
h
.
lastBody
[
"model"
])
}
if
h
.
lastBody
[
"system"
]
!=
"You are someone else"
{
t
.
Errorf
(
"replace mode should use user's system, got %v"
,
h
.
lastBody
[
"system"
])
}
// challenge 虽然没命中,但由于 replace 模式跳过 challenge 校验 + 响应非空 → operational
if
res
.
Status
!=
MonitorStatusOperational
{
t
.
Errorf
(
"replace mode with 2xx + non-empty text should be operational, got status=%s message=%q"
,
res
.
Status
,
res
.
Message
)
}
}
func
TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed
(
t
*
testing
.
T
)
{
h
:=
&
captureHandler
{
respondText
:
""
}
// 上游 200 但 content[0].text 为空
endpoint
:=
setupFakeAnthropic
(
t
,
h
)
opts
:=
&
CheckOptions
{
BodyOverrideMode
:
MonitorBodyOverrideModeReplace
,
BodyOverride
:
map
[
string
]
any
{
"model"
:
"x"
,
"messages"
:
[]
any
{}},
}
res
:=
runCheckForModel
(
context
.
Background
(),
MonitorProviderAnthropic
,
endpoint
,
"sk-fake"
,
"claude-x"
,
opts
)
if
res
.
Status
!=
MonitorStatusFailed
{
t
.
Errorf
(
"replace mode with empty text should be failed, got status=%s"
,
res
.
Status
)
}
if
!
strings
.
Contains
(
res
.
Message
,
"replace-mode"
)
{
t
.
Errorf
(
"failure message should hint replace-mode, got %q"
,
res
.
Message
)
}
}
backend/internal/service/channel_monitor_const.go
0 → 100644
View file @
ac114738
package
service
import
(
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Global ChannelMonitor constants.
// These are hard-coded values for the MVP stage and can be moved into config
// when needed.
const (
	// monitorRequestTimeout is the total timeout of a single model request
	// (including reading the body).
	monitorRequestTimeout = 45 * time.Second
	// monitorPingTimeout is the timeout of the HEAD request against the
	// endpoint origin.
	monitorPingTimeout = 8 * time.Second
	// monitorDegradedThreshold: a successful main request whose latency is at
	// or above this threshold is reported as degraded.
	monitorDegradedThreshold = 6 * time.Second
	// monitorHistoryRetentionDays is how many days of detailed history are kept.
	// 60s default interval * 30 days ≈ 43200 rows per monitor per model; a
	// typical deployment stays <= 2M rows, which PostgreSQL handles easily. So
	// the full detail is kept for a month and availability queries can run on
	// raw rows without relying on aggregates.
	// The channel_monitor_daily_rollups aggregate table is still kept as a
	// fallback for long-term history backfill / degraded queries.
	monitorHistoryRetentionDays = 30
	// monitorRollupRetentionDays is how many days of daily rollups are kept.
	// Rollup rows past this window are soft-deleted by RunDailyMaintenance.
	monitorRollupRetentionDays = 30
	// monitorMaintenanceMaxDaysPerRun is the maximum number of days aggregated
	// by one maintenance run. It bounds the first-launch backfill (30 days)
	// plus a small margin, avoiding long transactions.
	monitorMaintenanceMaxDaysPerRun = 35
	// monitorWorkerConcurrency is the number of monitors the scheduler runs
	// concurrently (pond pool capacity).
	monitorWorkerConcurrency = 5
	// monitorStartupLoadTimeout is the total timeout for loading all enabled
	// monitors once at Start.
	monitorStartupLoadTimeout = 10 * time.Second
	// monitorMinIntervalSeconds / monitorMaxIntervalSeconds bound the
	// user-configured check interval.
	monitorMinIntervalSeconds = 15
	monitorMaxIntervalSeconds = 3600
	// monitorMessageMaxBytes is the maximum byte length of the message field
	// (kept in line with the schema / migration).
	monitorMessageMaxBytes = 500
	// monitorResponseMaxBytes is the maximum number of bytes read from a single
	// model response, to prevent OOM.
	monitorResponseMaxBytes = 64 * 1024
	// monitorErrorBodySnippetMaxBytes is the maximum byte length of the upstream
	// body snippet kept for non-2xx responses. 300 bytes is enough for typical
	// structured errors (e.g. `{"error":{"message":"..."}}`) while leaving room
	// for the "upstream HTTP <status>: " prefix, so the final
	// monitorMessageMaxBytes (500) cap does not cut too much.
	monitorErrorBodySnippetMaxBytes = 300
	// monitorChallengeMin / monitorChallengeMax bound the challenge operands.
	monitorChallengeMin = 1
	monitorChallengeMax = 50
	// providerOpenAIPath is the OpenAI Chat Completions path.
	providerOpenAIPath = "/v1/chat/completions"
	// providerAnthropicPath is the Anthropic Messages path.
	providerAnthropicPath = "/v1/messages"
	// providerGeminiPathTemplate is the Gemini generateContent path template
	// (with a model placeholder).
	providerGeminiPathTemplate = "/v1beta/models/%s:generateContent"
	// MonitorProviderOpenAI / Anthropic / Gemini are the provider string
	// constants (also the actual values of the ent enum).
	MonitorProviderOpenAI    = "openai"
	MonitorProviderAnthropic = "anthropic"
	MonitorProviderGemini    = "gemini"
	// MonitorStatusOperational and friends are the monitor status string
	// constants (matching the ent enum).
	MonitorStatusOperational = "operational"
	MonitorStatusDegraded    = "degraded"
	MonitorStatusFailed      = "failed"
	MonitorStatusError       = "error"
	// monitorAvailability7Days / 15 / 30 are the aggregation query windows.
	monitorAvailability7Days  = 7
	monitorAvailability15Days = 15
	monitorAvailability30Days = 30
	// MonitorHistoryDefaultLimit is the default number of rows returned by a
	// history query (shared with the handler layer).
	MonitorHistoryDefaultLimit = 100
	// MonitorHistoryMaxLimit is the maximum number of rows returned by a
	// history query (shared with the handler layer).
	MonitorHistoryMaxLimit = 1000
	// monitorTimelineMaxPoints is the maximum number of history points returned
	// per monitor in the user-view timeline.
	monitorTimelineMaxPoints = 60
	// monitorEndpointResolveTimeout is the longest validateEndpoint may take to
	// resolve a hostname.
	monitorEndpointResolveTimeout = 5 * time.Second

	// ---- checker / runner behavior parameters (removing magic values) ----

	// monitorAnthropicAPIVersion is the Anthropic Messages API version header.
	monitorAnthropicAPIVersion = "2023-06-01"
	// monitorChallengeMaxTokens is the max_tokens of one challenge request
	// (enough to answer small-integer arithmetic).
	monitorChallengeMaxTokens = 50
	// monitorRunOneBuffer is the extra headroom added to runOne's total timeout
	// on top of the request timeout and the ping timeout.
	monitorRunOneBuffer = 10 * time.Second
	// monitorIdleConnTimeout is the HTTP transport's idle connection timeout.
	monitorIdleConnTimeout = 30 * time.Second
	// monitorTLSHandshakeTimeout is the HTTP transport's TLS handshake timeout.
	monitorTLSHandshakeTimeout = 10 * time.Second
	// monitorResponseHeaderTimeout is the HTTP transport's timeout for waiting
	// on response headers.
	monitorResponseHeaderTimeout = 30 * time.Second
	// monitorPingDiscardMaxBytes is the maximum number of response body bytes
	// discarded during a ping.
	monitorPingDiscardMaxBytes = 1024
	// monitorDialTimeout is the per-connection timeout of the custom dialer.
	monitorDialTimeout = 10 * time.Second
	// monitorDialKeepAlive is the keep-alive interval of the custom dialer.
	monitorDialKeepAlive = 30 * time.Second
)
// Business errors (declared together here to avoid scattering them).
var (
	// ErrChannelMonitorNotFound is returned when the requested monitor does not exist.
	ErrChannelMonitorNotFound = infraerrors.NotFound(
		"CHANNEL_MONITOR_NOT_FOUND", "channel monitor not found",
	)
	// ErrChannelMonitorInvalidProvider is returned for a provider outside openai/anthropic/gemini.
	ErrChannelMonitorInvalidProvider = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_PROVIDER", "provider must be one of openai/anthropic/gemini",
	)
	// ErrChannelMonitorInvalidInterval is returned when interval_seconds is outside [15, 3600].
	ErrChannelMonitorInvalidInterval = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_INTERVAL", "interval_seconds must be in [15, 3600]",
	)
	// ErrChannelMonitorInvalidEndpoint is returned when the endpoint is not a valid https URL.
	ErrChannelMonitorInvalidEndpoint = infraerrors.BadRequest(
		"CHANNEL_MONITOR_INVALID_ENDPOINT", "endpoint must be a valid https URL",
	)
	// ErrChannelMonitorEndpointScheme is returned when the endpoint does not use https.
	ErrChannelMonitorEndpointScheme = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_SCHEME", "endpoint must use https scheme",
	)
	// ErrChannelMonitorEndpointPath is returned when the endpoint carries a path, query or fragment.
	ErrChannelMonitorEndpointPath = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PATH", "endpoint must be base origin only (no path/query/fragment)",
	)
	// ErrChannelMonitorEndpointPrivate is returned when the endpoint is not a public host.
	ErrChannelMonitorEndpointPrivate = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_PRIVATE", "endpoint must be a public host",
	)
	// ErrChannelMonitorEndpointUnreachable is returned when the endpoint hostname cannot be resolved.
	ErrChannelMonitorEndpointUnreachable = infraerrors.BadRequest(
		"CHANNEL_MONITOR_ENDPOINT_UNREACHABLE", "endpoint hostname could not be resolved",
	)
	// ErrChannelMonitorMissingAPIKey is returned when a monitor is created without an api_key.
	ErrChannelMonitorMissingAPIKey = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_API_KEY", "api_key is required when creating a monitor",
	)
	// ErrChannelMonitorMissingPrimaryModel is returned when primary_model is missing.
	ErrChannelMonitorMissingPrimaryModel = infraerrors.BadRequest(
		"CHANNEL_MONITOR_MISSING_PRIMARY_MODEL", "primary_model is required",
	)
	// ErrChannelMonitorAPIKeyDecryptFailed is returned when the stored API key cannot be decrypted.
	ErrChannelMonitorAPIKeyDecryptFailed = infraerrors.InternalServer(
		"CHANNEL_MONITOR_KEY_DECRYPT_FAILED", "api key decryption failed; please re-edit the monitor with a fresh key",
	)
)
backend/internal/service/channel_monitor_runner.go
0 → 100644
View file @
ac114738
package
service
import
(
"context"
"log/slog"
"sync"
"time"
"github.com/alitto/pond/v2"
)
// MonitorScheduler is the scheduler interface that ChannelMonitorService calls
// back into on CRUD operations. It is injected through a setter to avoid a
// service <-> runner dependency cycle in wire.
type MonitorScheduler interface {
	// Schedule creates (or resets) the dedicated periodic task of the given
	// monitor. When m.Enabled=false it is equivalent to Unschedule(m.ID).
	Schedule(m *ChannelMonitor)
	// Unschedule cancels the periodic task of the given monitor, if any.
	Unschedule(id int64)
}
// monitorRunnerSvc captures the two service methods the runner actually
// depends on:
//   - loading the enabled monitors at startup
//   - running a check every time a ticker fires
//
// An interface is used instead of *ChannelMonitorService so that runner unit
// tests can inject a lightweight stub without the full repo + encryptor chain.
// The production implementation *ChannelMonitorService satisfies it naturally.
type monitorRunnerSvc interface {
	ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error)
	RunCheck(ctx context.Context, id int64) ([]*CheckResult, error)
}
// ChannelMonitorRunner schedules channel-monitor checks.
//
// Design:
//   - every enabled monitor gets its own goroutine + ticker (driven by its IntervalSeconds)
//   - Start loads all enabled monitors once and creates a task for each
//   - after Create/Update/Delete the service calls back through the MonitorScheduler
//     interface so the matching task is rebuilt/cancelled immediately (no DB polling)
//   - the actual HTTP checks are handed to a pond pool (capacity monitorWorkerConcurrency)
//     so that bursts of concurrency do not overwhelm upstreams
//
// History cleanup and daily rollup maintenance are triggered by the
// OpsCleanupService cron via ChannelMonitorService.RunDailyMaintenance
// (reusing its leader lock + heartbeat) and are not the runner's responsibility.
type ChannelMonitorRunner struct {
	svc            monitorRunnerSvc
	settingService *SettingService
	pool           pond.Pool
	parentCtx      context.Context
	parentCancel   context.CancelFunc
	mu             sync.Mutex
	tasks          map[int64]*scheduledMonitor
	wg             sync.WaitGroup
	started        bool
	stopped        bool
	// inFlight tracks the monitor IDs that are currently being checked. fire consults
	// it before submitting so that the same monitor is never executed concurrently
	// when a single check takes longer than its interval.
	inFlight   map[int64]struct{}
	inFlightMu sync.Mutex
}
// scheduledMonitor is the runtime context of a single scheduled monitor:
// its identity (id, name), its tick interval, and the cancel function of the
// context that its scheduling goroutine runs under.
type scheduledMonitor struct {
	id       int64
	name     string
	interval time.Duration
	cancel   context.CancelFunc
}
// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用一次。
// settingService 用于在每次 fire 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
//
// pool 在构造时即建好:避免 Start 在 mu 内赋值、fire/Stop 在 mu 外读取的竞态隐患,
// 且 pond.NewPool 创建本身近似零开销,提前建池不会浪费资源。
func
NewChannelMonitorRunner
(
svc
*
ChannelMonitorService
,
settingService
*
SettingService
)
*
ChannelMonitorRunner
{
return
newChannelMonitorRunner
(
svc
,
settingService
)
}
// newChannelMonitorRunner 内部构造,接受最小化接口,便于单元测试注入 stub。
func
newChannelMonitorRunner
(
svc
monitorRunnerSvc
,
settingService
*
SettingService
)
*
ChannelMonitorRunner
{
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
return
&
ChannelMonitorRunner
{
svc
:
svc
,
settingService
:
settingService
,
pool
:
pond
.
NewPool
(
monitorWorkerConcurrency
),
parentCtx
:
ctx
,
parentCancel
:
cancel
,
tasks
:
make
(
map
[
int64
]
*
scheduledMonitor
),
inFlight
:
make
(
map
[
int64
]
struct
{}),
}
}
// Start 加载所有 enabled monitor 并为每个建立独立定时任务。
// 调用方需保证只调一次(wire ProvideChannelMonitorRunner 内只调一次)。
func
(
r
*
ChannelMonitorRunner
)
Start
()
{
if
r
==
nil
||
r
.
svc
==
nil
{
return
}
r
.
mu
.
Lock
()
if
r
.
started
||
r
.
stopped
{
r
.
mu
.
Unlock
()
return
}
r
.
started
=
true
r
.
mu
.
Unlock
()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
monitorStartupLoadTimeout
)
defer
cancel
()
enabled
,
err
:=
r
.
svc
.
ListEnabledMonitors
(
ctx
)
if
err
!=
nil
{
slog
.
Error
(
"channel_monitor: load enabled monitors failed at startup"
,
"error"
,
err
)
return
}
for
_
,
m
:=
range
enabled
{
r
.
Schedule
(
m
)
}
slog
.
Info
(
"channel_monitor: runner started"
,
"scheduled_tasks"
,
len
(
enabled
))
}
// Schedule 为指定监控创建(或重置)独立定时任务。
// - m.Enabled=false → 等同于 Unschedule(m.ID)
// - 已存在的任务会先被取消再重建(适用于 IntervalSeconds 变更场景)
// - 新任务立即触发首次检测,之后按 IntervalSeconds 周期触发
func
(
r
*
ChannelMonitorRunner
)
Schedule
(
m
*
ChannelMonitor
)
{
if
r
==
nil
||
m
==
nil
{
return
}
if
!
m
.
Enabled
{
r
.
Unschedule
(
m
.
ID
)
return
}
interval
:=
time
.
Duration
(
m
.
IntervalSeconds
)
*
time
.
Second
if
interval
<=
0
{
// Create/Update 已通过 validateInterval 校验区间,正常路径不可能到这里。
// 真触发说明数据库中存在违反约束的数据或校验链路有 bug,记 Error 暴露问题。
slog
.
Error
(
"channel_monitor: skip schedule for invalid interval"
,
"monitor_id"
,
m
.
ID
,
"interval_seconds"
,
m
.
IntervalSeconds
)
return
}
r
.
mu
.
Lock
()
if
r
.
stopped
{
r
.
mu
.
Unlock
()
return
}
if
!
r
.
started
{
// Start 之前调用 Schedule 通常意味着 wire 顺序错乱:
// 当前 wire 顺序是 SetScheduler → Start,CRUD 钩子最早也只能在请求到达时触发,
// 此时 Start 早已完成。出现此分支时把 monitor 信息打出来便于排查,
// 不入队、不缓存——交给运维通过重启或修复 wire 解决。
r
.
mu
.
Unlock
()
slog
.
Warn
(
"channel_monitor: schedule before runner started, skip"
,
"monitor_id"
,
m
.
ID
,
"name"
,
m
.
Name
)
return
}
if
existing
,
ok
:=
r
.
tasks
[
m
.
ID
];
ok
{
existing
.
cancel
()
}
ctx
,
cancel
:=
context
.
WithCancel
(
r
.
parentCtx
)
task
:=
&
scheduledMonitor
{
id
:
m
.
ID
,
name
:
m
.
Name
,
interval
:
interval
,
cancel
:
cancel
,
}
r
.
tasks
[
m
.
ID
]
=
task
r
.
wg
.
Add
(
1
)
r
.
mu
.
Unlock
()
go
r
.
runScheduled
(
ctx
,
task
)
}
// Unschedule 取消指定监控的定时任务(若存在)。
// 已经在执行中的检测会通过 ctx 取消信号传递。
func
(
r
*
ChannelMonitorRunner
)
Unschedule
(
id
int64
)
{
if
r
==
nil
{
return
}
r
.
mu
.
Lock
()
task
,
ok
:=
r
.
tasks
[
id
]
if
ok
{
delete
(
r
.
tasks
,
id
)
}
r
.
mu
.
Unlock
()
if
ok
{
task
.
cancel
()
}
}
// Stop 优雅停止:取消所有任务、关闭池。
func
(
r
*
ChannelMonitorRunner
)
Stop
()
{
if
r
==
nil
{
return
}
r
.
mu
.
Lock
()
if
r
.
stopped
{
r
.
mu
.
Unlock
()
return
}
r
.
stopped
=
true
r
.
parentCancel
()
r
.
tasks
=
nil
r
.
mu
.
Unlock
()
r
.
wg
.
Wait
()
r
.
pool
.
StopAndWait
()
}
// runScheduled 单个监控的循环:立即触发首次(满足"新建/启用即跑"),
// 之后按 interval 周期触发;ctx 取消即退出。
func
(
r
*
ChannelMonitorRunner
)
runScheduled
(
ctx
context
.
Context
,
task
*
scheduledMonitor
)
{
defer
r
.
wg
.
Done
()
r
.
fire
(
ctx
,
task
)
ticker
:=
time
.
NewTicker
(
task
.
interval
)
defer
ticker
.
Stop
()
for
{
select
{
case
<-
ctx
.
Done
()
:
return
case
<-
ticker
.
C
:
r
.
fire
(
ctx
,
task
)
}
}
}
// fire 提交一次检测到 worker 池。功能开关关闭时跳过本次(不取消任务,
// 重新启用时立即恢复);池满或重复在飞时也跳过。
func
(
r
*
ChannelMonitorRunner
)
fire
(
ctx
context
.
Context
,
task
*
scheduledMonitor
)
{
if
r
.
settingService
!=
nil
&&
!
r
.
settingService
.
GetChannelMonitorRuntime
(
ctx
)
.
Enabled
{
return
}
if
!
r
.
tryAcquireInFlight
(
task
.
id
)
{
slog
.
Debug
(
"channel_monitor: skip already in-flight"
,
"monitor_id"
,
task
.
id
,
"name"
,
task
.
name
)
return
}
if
_
,
ok
:=
r
.
pool
.
TrySubmit
(
func
()
{
r
.
runOne
(
task
.
id
,
task
.
name
)
});
!
ok
{
// 池满:丢弃本次检测,但必须释放已占用的 inFlight 槽,否则该 monitor 会被永久卡住。
r
.
releaseInFlight
(
task
.
id
)
slog
.
Warn
(
"channel_monitor: worker pool full, skip submission"
,
"monitor_id"
,
task
.
id
,
"name"
,
task
.
name
)
}
}
// tryAcquireInFlight 原子地占用 monitor 的 in-flight 槽。
// 已被占用返回 false(调用方应跳过本次提交)。
func
(
r
*
ChannelMonitorRunner
)
tryAcquireInFlight
(
id
int64
)
bool
{
r
.
inFlightMu
.
Lock
()
defer
r
.
inFlightMu
.
Unlock
()
if
_
,
exists
:=
r
.
inFlight
[
id
];
exists
{
return
false
}
r
.
inFlight
[
id
]
=
struct
{}{}
return
true
}
// releaseInFlight 释放 in-flight 槽。runOne 完成(含 panic recover)后必须调用。
func
(
r
*
ChannelMonitorRunner
)
releaseInFlight
(
id
int64
)
{
r
.
inFlightMu
.
Lock
()
delete
(
r
.
inFlight
,
id
)
r
.
inFlightMu
.
Unlock
()
}
// runOne 执行单个监控的检测。所有错误只记日志,不熔断。
// 任务结束时(含 panic recover)必须释放 in-flight 槽。
func
(
r
*
ChannelMonitorRunner
)
runOne
(
id
int64
,
name
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
monitorRequestTimeout
+
monitorPingTimeout
+
monitorRunOneBuffer
)
defer
cancel
()
defer
r
.
releaseInFlight
(
id
)
defer
func
()
{
if
rec
:=
recover
();
rec
!=
nil
{
slog
.
Error
(
"channel_monitor: runner panic"
,
"monitor_id"
,
id
,
"name"
,
name
,
"panic"
,
rec
)
}
}()
if
_
,
err
:=
r
.
svc
.
RunCheck
(
ctx
,
id
);
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: run check failed"
,
"monitor_id"
,
id
,
"name"
,
name
,
"error"
,
err
)
}
}
backend/internal/service/channel_monitor_runner_test.go
0 → 100644
View file @
ac114738
//go:build unit
package
service
import
(
"context"
"sync"
"sync/atomic"
"testing"
"time"
)
// stubMonitorSvc implements monitorRunnerSvc so the runner can be tested in
// isolation from the real service/repository.
type stubMonitorSvc struct {
	enabled  []*ChannelMonitor
	runCount atomic.Int64
	// runCalled receives the monitor id each time RunCheck is invoked; give it a
	// buffer large enough that it never blocks.
	runCalled chan int64
	runErr    error
	listErr   error
	// runHoldFor is extra time RunCheck blocks for; used to test how Stop waits.
	runHoldFor time.Duration
}
func
(
s
*
stubMonitorSvc
)
ListEnabledMonitors
(
_
context
.
Context
)
([]
*
ChannelMonitor
,
error
)
{
if
s
.
listErr
!=
nil
{
return
nil
,
s
.
listErr
}
return
s
.
enabled
,
nil
}
func
(
s
*
stubMonitorSvc
)
RunCheck
(
ctx
context
.
Context
,
id
int64
)
([]
*
CheckResult
,
error
)
{
s
.
runCount
.
Add
(
1
)
if
s
.
runCalled
!=
nil
{
select
{
case
s
.
runCalled
<-
id
:
default
:
}
}
if
s
.
runHoldFor
>
0
{
select
{
case
<-
time
.
After
(
s
.
runHoldFor
)
:
case
<-
ctx
.
Done
()
:
}
}
return
nil
,
s
.
runErr
}
func
newRunnerForTest
(
svc
monitorRunnerSvc
)
*
ChannelMonitorRunner
{
return
newChannelMonitorRunner
(
svc
,
nil
)
}
// waitFor polls cond every 5ms until it returns true or timeout elapses. Once
// the deadline has passed, cond gets one final chance before the test is failed
// with t.Fatalf.
func waitFor(t *testing.T, timeout time.Duration, msg string, cond func() bool) {
	t.Helper()
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(5 * time.Millisecond) {
		if cond() {
			return
		}
	}
	if cond() {
		return
	}
	t.Fatalf("waitFor timed out: %s", msg)
}
func
runnerTaskCount
(
r
*
ChannelMonitorRunner
)
int
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
return
len
(
r
.
tasks
)
}
func
runnerTaskPtr
(
r
*
ChannelMonitorRunner
,
id
int64
)
*
scheduledMonitor
{
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
return
r
.
tasks
[
id
]
}
// TestSchedule_AddsTaskAndFiresOnce 验证 Schedule 后立即触发一次首检测,并把任务记入 tasks 表。
func
TestSchedule_AddsTaskAndFiresOnce
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
runCalled
:
make
(
chan
int64
,
4
)}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
// svc.enabled 为空,Start 立即完成
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
1
,
Name
:
"m1"
,
Enabled
:
true
,
IntervalSeconds
:
60
})
if
got
:=
runnerTaskCount
(
r
);
got
!=
1
{
t
.
Fatalf
(
"expected 1 scheduled task, got %d"
,
got
)
}
select
{
case
id
:=
<-
svc
.
runCalled
:
if
id
!=
1
{
t
.
Fatalf
(
"expected first fire for id=1, got %d"
,
id
)
}
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"expected immediate first fire within 2s"
)
}
r
.
Stop
()
}
// TestSchedule_ReplaceCancelsOldTask 验证对同一 id 二次 Schedule 会替换旧 task 实例。
// (旧 goroutine 通过 ctx 取消退出;这里以 task 指针不同 + Stop 不超时作为证据。)
func
TestSchedule_ReplaceCancelsOldTask
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
runCalled
:
make
(
chan
int64
,
8
)}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
m
:=
&
ChannelMonitor
{
ID
:
7
,
Name
:
"m7"
,
Enabled
:
true
,
IntervalSeconds
:
60
}
r
.
Schedule
(
m
)
first
:=
runnerTaskPtr
(
r
,
7
)
if
first
==
nil
{
t
.
Fatal
(
"first schedule did not register task"
)
}
r
.
Schedule
(
m
)
second
:=
runnerTaskPtr
(
r
,
7
)
if
second
==
nil
{
t
.
Fatal
(
"second schedule did not register task"
)
}
if
first
==
second
{
t
.
Fatal
(
"re-Schedule should create a new scheduledMonitor instance"
)
}
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
}
// TestUnschedule_RemovesTask 验证 Unschedule 删除 task 并使对应 goroutine 退出。
func
TestUnschedule_RemovesTask
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
runCalled
:
make
(
chan
int64
,
4
)}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
3
,
Enabled
:
true
,
IntervalSeconds
:
60
})
waitFor
(
t
,
time
.
Second
,
"task registered"
,
func
()
bool
{
return
runnerTaskCount
(
r
)
==
1
})
r
.
Unschedule
(
3
)
if
got
:=
runnerTaskCount
(
r
);
got
!=
0
{
t
.
Fatalf
(
"expected tasks empty after Unschedule, got %d"
,
got
)
}
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
}
// TestSchedule_DisabledRedirectsToUnschedule 验证 Enabled=false 等同于 Unschedule。
func
TestSchedule_DisabledRedirectsToUnschedule
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
runCalled
:
make
(
chan
int64
,
4
)}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
9
,
Enabled
:
true
,
IntervalSeconds
:
60
})
waitFor
(
t
,
time
.
Second
,
"task registered"
,
func
()
bool
{
return
runnerTaskCount
(
r
)
==
1
})
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
9
,
Enabled
:
false
,
IntervalSeconds
:
60
})
if
got
:=
runnerTaskCount
(
r
);
got
!=
0
{
t
.
Fatalf
(
"expected tasks empty after disabled re-Schedule, got %d"
,
got
)
}
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
}
// TestSchedule_InvalidIntervalSkipped 验证 IntervalSeconds<=0 不会注册任务(防御性检查)。
func
TestSchedule_InvalidIntervalSkipped
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
1
,
Enabled
:
true
,
IntervalSeconds
:
0
})
if
got
:=
runnerTaskCount
(
r
);
got
!=
0
{
t
.
Fatalf
(
"expected no task for invalid interval, got %d"
,
got
)
}
r
.
Stop
()
}
// TestSchedule_BeforeStartIsNoOp 验证 Start 之前调用 Schedule 不会注册任务。
func
TestSchedule_BeforeStartIsNoOp
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{}
r
:=
newRunnerForTest
(
svc
)
// 故意不调用 Start
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
1
,
Enabled
:
true
,
IntervalSeconds
:
60
})
if
got
:=
runnerTaskCount
(
r
);
got
!=
0
{
t
.
Fatalf
(
"expected no task before Start, got %d"
,
got
)
}
r
.
Stop
()
}
// TestStart_LoadsAllEnabledMonitors 验证 Start 会为 ListEnabledMonitors 返回的每条记录建立任务。
func
TestStart_LoadsAllEnabledMonitors
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
enabled
:
[]
*
ChannelMonitor
{
{
ID
:
1
,
Enabled
:
true
,
IntervalSeconds
:
60
},
{
ID
:
2
,
Enabled
:
true
,
IntervalSeconds
:
60
},
{
ID
:
3
,
Enabled
:
true
,
IntervalSeconds
:
60
},
},
}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
waitFor
(
t
,
2
*
time
.
Second
,
"all 3 tasks scheduled"
,
func
()
bool
{
return
runnerTaskCount
(
r
)
==
3
})
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
}
// TestStop_DrainsAllGoroutines 验证 Stop 会等待所有调度 goroutine 退出(无游离)。
func
TestStop_DrainsAllGoroutines
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
for
id
:=
int64
(
1
);
id
<=
5
;
id
++
{
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
id
,
Enabled
:
true
,
IntervalSeconds
:
60
})
}
waitFor
(
t
,
2
*
time
.
Second
,
"5 tasks scheduled"
,
func
()
bool
{
return
runnerTaskCount
(
r
)
==
5
})
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
}
// TestStop_WaitsForInFlightCheck 验证 Stop 会等待正在执行的 RunCheck 退出(pool.StopAndWait)。
func
TestStop_WaitsForInFlightCheck
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{
runCalled
:
make
(
chan
int64
,
1
),
runHoldFor
:
200
*
time
.
Millisecond
,
}
r
:=
newRunnerForTest
(
svc
)
r
.
Start
()
r
.
Schedule
(
&
ChannelMonitor
{
ID
:
1
,
Enabled
:
true
,
IntervalSeconds
:
60
})
select
{
case
<-
svc
.
runCalled
:
case
<-
time
.
After
(
2
*
time
.
Second
)
:
t
.
Fatal
(
"first fire never happened"
)
}
start
:=
time
.
Now
()
stoppedWithin
(
t
,
r
,
3
*
time
.
Second
)
elapsed
:=
time
.
Since
(
start
)
// Stop 必须等待 in-flight check 跑完(runHoldFor=200ms),耗时下界约 100ms。
if
elapsed
<
100
*
time
.
Millisecond
{
t
.
Fatalf
(
"Stop returned too fast (%v); did not wait for in-flight check"
,
elapsed
)
}
}
// TestInFlight_PoolFullReleasesSlot 直接驱动 fire 路径,模拟 pool.TrySubmit 失败时 inFlight 必须释放。
// 用一个小型 stub pool 替换 r.pool 不便(pond.Pool 是接口但 mock 麻烦),
// 改为:占满 inFlight 后直接 fire,验证不会在 inFlight 空槽时永久卡住。
func
TestInFlight_AcquireReleaseSymmetric
(
t
*
testing
.
T
)
{
svc
:=
&
stubMonitorSvc
{}
r
:=
newRunnerForTest
(
svc
)
if
!
r
.
tryAcquireInFlight
(
42
)
{
t
.
Fatal
(
"first acquire should succeed"
)
}
if
r
.
tryAcquireInFlight
(
42
)
{
t
.
Fatal
(
"second acquire (no release) must fail"
)
}
r
.
releaseInFlight
(
42
)
if
!
r
.
tryAcquireInFlight
(
42
)
{
t
.
Fatal
(
"acquire after release should succeed"
)
}
r
.
releaseInFlight
(
42
)
}
// stoppedWithin 在 timeout 内并行调用 Stop,超时则 Fatal。验证 Stop 不会阻塞。
func
stoppedWithin
(
t
*
testing
.
T
,
r
*
ChannelMonitorRunner
,
timeout
time
.
Duration
)
{
t
.
Helper
()
done
:=
make
(
chan
struct
{})
var
once
sync
.
Once
go
func
()
{
r
.
Stop
()
once
.
Do
(
func
()
{
close
(
done
)
})
}()
select
{
case
<-
done
:
case
<-
time
.
After
(
timeout
)
:
t
.
Fatalf
(
"Stop did not return within %s — leaked goroutine?"
,
timeout
)
}
}
backend/internal/service/channel_monitor_service.go
0 → 100644
View file @
ac114738
package
service
import
(
"context"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"golang.org/x/sync/errgroup"
)
// ChannelMonitorRepository is the data-access interface for channel monitors.
// Pointer parameters and results use the service package's ChannelMonitor model;
// the repository implementation is responsible for converting to/from the ent
// model and for keeping the api_key_encrypted field as ciphertext.
type ChannelMonitorRepository interface {
	// CRUD
	Create(ctx context.Context, m *ChannelMonitor) error
	GetByID(ctx context.Context, id int64) (*ChannelMonitor, error)
	Update(ctx context.Context, m *ChannelMonitor) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error)

	// Scheduler helpers
	ListEnabled(ctx context.Context) ([]*ChannelMonitor, error)
	MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error
	InsertHistoryBatch(ctx context.Context, rows []*ChannelMonitorHistoryRow) error
	DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error)

	// History records
	ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error)

	// Aggregations for the user view
	ListLatestPerModel(ctx context.Context, monitorID int64) ([]*ChannelMonitorLatest, error)
	ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*ChannelMonitorAvailability, error)

	// Batch aggregations (used by the admin/user lists to avoid N+1 queries)
	ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error)
	ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error)
	// ListRecentHistoryForMonitors fetches, for several monitors at once, the most
	// recent perMonitorLimit history entries of each monitor's primary model
	// (primaryModels[monitorID]). The returned entries are sorted by checked_at DESC
	// (newest first) and do not include the message field.
	ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)

	// ---------- Rollup maintenance (called by OpsCleanupService) ----------

	// UpsertDailyRollupsFor aggregates the detail rows of targetDate by
	// (monitor_id, model, bucket_date) into channel_monitor_daily_rollups.
	// targetDate is truncated to a date; ON CONFLICT DO UPDATE makes the backfill
	// idempotent. It returns the number of rows affected by the upsert.
	UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error)
	// DeleteRollupsBefore soft-deletes rollup rows with bucket_date < beforeDate
	// and returns the number of deleted rows.
	DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error)
	// LoadAggregationWatermark reads the watermark (id=1). A nil result means no
	// aggregation has ever run; the watermark table itself is expected to already
	// contain its single row (written by migration 110).
	LoadAggregationWatermark(ctx context.Context) (*time.Time, error)
	// UpdateAggregationWatermark writes the watermark (UPSERT into id=1).
	UpdateAggregationWatermark(ctx context.Context, date time.Time) error
}
// ChannelMonitorService is the management service for channel monitors.
type ChannelMonitorService struct {
	repo      ChannelMonitorRepository
	encryptor SecretEncryptor
	// scheduler is injected by wire through SetScheduler; after CRUD operations the
	// matching hook is called so the task table is synchronised immediately.
	// It stays nil in tests or when not injected, in which case the CRUD methods
	// skip the hook calls.
	scheduler MonitorScheduler
}
// NewChannelMonitorService 创建渠道监控服务实例。
func
NewChannelMonitorService
(
repo
ChannelMonitorRepository
,
encryptor
SecretEncryptor
)
*
ChannelMonitorService
{
return
&
ChannelMonitorService
{
repo
:
repo
,
encryptor
:
encryptor
}
}
// ---------- CRUD ----------
// List 列表查询(支持 provider/enabled/search 过滤 + 分页)。
// 返回的 ChannelMonitor.APIKey 已解密为明文,handler 层负责脱敏。
func
(
s
*
ChannelMonitorService
)
List
(
ctx
context
.
Context
,
params
ChannelMonitorListParams
)
([]
*
ChannelMonitor
,
int64
,
error
)
{
if
params
.
Page
<
1
{
params
.
Page
=
1
}
if
params
.
PageSize
<
1
||
params
.
PageSize
>
200
{
params
.
PageSize
=
20
}
items
,
total
,
err
:=
s
.
repo
.
List
(
ctx
,
params
)
if
err
!=
nil
{
return
nil
,
0
,
fmt
.
Errorf
(
"list channel monitors: %w"
,
err
)
}
for
_
,
it
:=
range
items
{
s
.
decryptInPlace
(
it
)
}
return
items
,
total
,
nil
}
// Get 查询单个监控(解密 API Key)。
func
(
s
*
ChannelMonitorService
)
Get
(
ctx
context
.
Context
,
id
int64
)
(
*
ChannelMonitor
,
error
)
{
m
,
err
:=
s
.
repo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
s
.
decryptInPlace
(
m
)
return
m
,
nil
}
// Create 创建监控(内部加密 api_key)。
func
(
s
*
ChannelMonitorService
)
Create
(
ctx
context
.
Context
,
p
ChannelMonitorCreateParams
)
(
*
ChannelMonitor
,
error
)
{
if
err
:=
validateCreateParams
(
p
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
validateBodyModeParams
(
p
.
BodyOverrideMode
,
p
.
BodyOverride
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
validateExtraHeaders
(
p
.
ExtraHeaders
);
err
!=
nil
{
return
nil
,
err
}
encrypted
,
err
:=
s
.
encryptor
.
Encrypt
(
p
.
APIKey
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"encrypt api key: %w"
,
err
)
}
m
:=
&
ChannelMonitor
{
Name
:
strings
.
TrimSpace
(
p
.
Name
),
Provider
:
p
.
Provider
,
Endpoint
:
normalizeEndpoint
(
p
.
Endpoint
),
APIKey
:
encrypted
,
// 注意:传入 repository 时该字段为密文
PrimaryModel
:
strings
.
TrimSpace
(
p
.
PrimaryModel
),
ExtraModels
:
normalizeModels
(
p
.
ExtraModels
),
GroupName
:
strings
.
TrimSpace
(
p
.
GroupName
),
Enabled
:
p
.
Enabled
,
IntervalSeconds
:
p
.
IntervalSeconds
,
CreatedBy
:
p
.
CreatedBy
,
TemplateID
:
p
.
TemplateID
,
ExtraHeaders
:
emptyHeadersIfNil
(
p
.
ExtraHeaders
),
BodyOverrideMode
:
defaultBodyMode
(
p
.
BodyOverrideMode
),
BodyOverride
:
p
.
BodyOverride
,
}
if
err
:=
s
.
repo
.
Create
(
ctx
,
m
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"create channel monitor: %w"
,
err
)
}
// 不再调 s.Get 重走解密链:已知刚加密的明文,直接构造响应。
// 这样可避免 SecretEncryptor 解密失败时 APIKey 被静默清空的问题(见 Fix 4)。
m
.
APIKey
=
strings
.
TrimSpace
(
p
.
APIKey
)
if
s
.
scheduler
!=
nil
{
s
.
scheduler
.
Schedule
(
m
)
}
return
m
,
nil
}
// validateCreateParams 把 Create 入参的所有校验聚拢为一个函数,避免 Create 主体超过 30 行。
func
validateCreateParams
(
p
ChannelMonitorCreateParams
)
error
{
if
err
:=
validateProvider
(
p
.
Provider
);
err
!=
nil
{
return
err
}
if
err
:=
validateInterval
(
p
.
IntervalSeconds
);
err
!=
nil
{
return
err
}
if
err
:=
validateEndpoint
(
p
.
Endpoint
);
err
!=
nil
{
return
err
}
if
strings
.
TrimSpace
(
p
.
APIKey
)
==
""
{
return
ErrChannelMonitorMissingAPIKey
}
if
strings
.
TrimSpace
(
p
.
PrimaryModel
)
==
""
{
return
ErrChannelMonitorMissingPrimaryModel
}
return
nil
}
// Update 更新监控。APIKey 字段:nil 或空字符串 = 不修改;非空 = 加密后覆盖。
func
(
s
*
ChannelMonitorService
)
Update
(
ctx
context
.
Context
,
id
int64
,
p
ChannelMonitorUpdateParams
)
(
*
ChannelMonitor
,
error
)
{
existing
,
err
:=
s
.
repo
.
GetByID
(
ctx
,
id
)
if
err
!=
nil
{
return
nil
,
err
}
if
err
:=
applyMonitorUpdate
(
existing
,
p
);
err
!=
nil
{
return
nil
,
err
}
newPlainAPIKey
,
apiKeyUpdated
,
err
:=
s
.
applyAPIKeyUpdate
(
existing
,
p
.
APIKey
)
if
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
repo
.
Update
(
ctx
,
existing
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update channel monitor: %w"
,
err
)
}
// 不再调 s.Get 重走解密链:避免二次解密带来的"密文被静默清空"风险(与 Create 一致)。
if
apiKeyUpdated
{
existing
.
APIKey
=
newPlainAPIKey
}
else
{
s
.
decryptInPlace
(
existing
)
}
if
s
.
scheduler
!=
nil
{
// Schedule 内部根据 Enabled 自动选择 Unschedule 或重建任务,
// IntervalSeconds 变化也会被自然吸收(旧 task 取消 + 新 task 用新 interval)。
s
.
scheduler
.
Schedule
(
existing
)
}
return
existing
,
nil
}
// applyAPIKeyUpdate 处理 Update 中的 APIKey 字段:
// - 入参 raw 为 nil 或空白:不修改 existing.APIKey(仍为密文),返回 updated=false
// - 非空:加密后写入 existing.APIKey;同时把明文返回给调用方,
// 供写库成功后塞回 existing 避免把密文吐回客户端
func
(
s
*
ChannelMonitorService
)
applyAPIKeyUpdate
(
existing
*
ChannelMonitor
,
raw
*
string
)
(
plain
string
,
updated
bool
,
err
error
)
{
if
raw
==
nil
||
strings
.
TrimSpace
(
*
raw
)
==
""
{
return
""
,
false
,
nil
}
plain
=
strings
.
TrimSpace
(
*
raw
)
encrypted
,
encErr
:=
s
.
encryptor
.
Encrypt
(
plain
)
if
encErr
!=
nil
{
return
""
,
false
,
fmt
.
Errorf
(
"encrypt api key: %w"
,
encErr
)
}
existing
.
APIKey
=
encrypted
return
plain
,
true
,
nil
}
// Delete 删除监控(历史通过外键 CASCADE 自动清理)。
func
(
s
*
ChannelMonitorService
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
if
err
:=
s
.
repo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
return
fmt
.
Errorf
(
"delete channel monitor: %w"
,
err
)
}
if
s
.
scheduler
!=
nil
{
s
.
scheduler
.
Unschedule
(
id
)
}
return
nil
}
// ListHistory 列出某个监控最近的检测历史。
// model 为空表示返回所有模型;limit <= 0 时使用默认值,超过上限会被截断。
func
(
s
*
ChannelMonitorService
)
ListHistory
(
ctx
context
.
Context
,
id
int64
,
model
string
,
limit
int
)
([]
*
ChannelMonitorHistoryEntry
,
error
)
{
if
_
,
err
:=
s
.
repo
.
GetByID
(
ctx
,
id
);
err
!=
nil
{
return
nil
,
err
}
if
limit
<=
0
{
limit
=
MonitorHistoryDefaultLimit
}
if
limit
>
MonitorHistoryMaxLimit
{
limit
=
MonitorHistoryMaxLimit
}
entries
,
err
:=
s
.
repo
.
ListHistory
(
ctx
,
id
,
strings
.
TrimSpace
(
model
),
limit
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"list history: %w"
,
err
)
}
return
entries
,
nil
}
// ---------- 业务 ----------
// RunCheck 同步触发对一个监控的检测:并发跑 primary + extra 模型,
// 写历史记录并更新 last_checked_at。返回每个模型的检测结果。
func
(
s
*
ChannelMonitorService
)
RunCheck
(
ctx
context
.
Context
,
id
int64
)
([]
*
CheckResult
,
error
)
{
m
,
err
:=
s
.
Get
(
ctx
,
id
)
// 已解密 APIKey
if
err
!=
nil
{
return
nil
,
err
}
if
m
.
APIKeyDecryptFailed
{
return
nil
,
ErrChannelMonitorAPIKeyDecryptFailed
}
results
:=
s
.
runChecksConcurrent
(
ctx
,
m
)
s
.
persistCheckResults
(
ctx
,
m
,
results
)
return
results
,
nil
}
// persistCheckResults 写入本次检测的历史记录并更新 last_checked_at。
// 任一写库失败都只记日志,不影响调用方拿到 results(与 MVP 期望一致:宁可漏记历史也要先返回结果)。
func
(
s
*
ChannelMonitorService
)
persistCheckResults
(
ctx
context
.
Context
,
m
*
ChannelMonitor
,
results
[]
*
CheckResult
)
{
rows
:=
make
([]
*
ChannelMonitorHistoryRow
,
0
,
len
(
results
))
for
_
,
r
:=
range
results
{
rows
=
append
(
rows
,
&
ChannelMonitorHistoryRow
{
MonitorID
:
m
.
ID
,
Model
:
r
.
Model
,
Status
:
r
.
Status
,
LatencyMs
:
r
.
LatencyMs
,
PingLatencyMs
:
r
.
PingLatencyMs
,
Message
:
r
.
Message
,
CheckedAt
:
r
.
CheckedAt
,
})
}
if
err
:=
s
.
repo
.
InsertHistoryBatch
(
ctx
,
rows
);
err
!=
nil
{
slog
.
Error
(
"channel_monitor: insert history failed"
,
"monitor_id"
,
m
.
ID
,
"name"
,
m
.
Name
,
"error"
,
err
)
}
if
err
:=
s
.
repo
.
MarkChecked
(
ctx
,
m
.
ID
,
time
.
Now
());
err
!=
nil
{
slog
.
Error
(
"channel_monitor: mark checked failed"
,
"monitor_id"
,
m
.
ID
,
"error"
,
err
)
}
}
// runChecksConcurrent 对 primary + extra 模型并发执行检测。
// errgroup 仅用于等待,不传播错误(每个 model 失败都已打包进 CheckResult)。
func
(
s
*
ChannelMonitorService
)
runChecksConcurrent
(
ctx
context
.
Context
,
m
*
ChannelMonitor
)
[]
*
CheckResult
{
models
:=
append
([]
string
{
m
.
PrimaryModel
},
m
.
ExtraModels
...
)
results
:=
make
([]
*
CheckResult
,
len
(
models
))
// ping 共享一次,所有模型记录同一个 ping 延迟。
pingMs
:=
pingEndpointOrigin
(
ctx
,
m
.
Endpoint
)
// 所有模型共用同一份 CheckOptions(来自监控的快照字段)。
opts
:=
&
CheckOptions
{
ExtraHeaders
:
m
.
ExtraHeaders
,
BodyOverrideMode
:
m
.
BodyOverrideMode
,
BodyOverride
:
m
.
BodyOverride
,
}
var
eg
errgroup
.
Group
var
mu
sync
.
Mutex
for
i
,
model
:=
range
models
{
i
,
model
:=
i
,
model
eg
.
Go
(
func
()
error
{
r
:=
runCheckForModel
(
ctx
,
m
.
Provider
,
m
.
Endpoint
,
m
.
APIKey
,
model
,
opts
)
r
.
PingLatencyMs
=
pingMs
mu
.
Lock
()
results
[
i
]
=
r
mu
.
Unlock
()
return
nil
})
}
_
=
eg
.
Wait
()
return
results
}
// ---------- 调度器协作 ----------
// SetScheduler is called by wire after the runner has been constructed; the
// injected scheduler is used to synchronise the task table immediately on CRUD.
// Setter injection avoids a dependency cycle between the service and the runner.
// The field is assigned without any locking.
func (s *ChannelMonitorService) SetScheduler(sched MonitorScheduler) {
	s.scheduler = sched
}
// ListEnabledMonitors 返回所有 enabled=true 的监控(解密后),供 runner 启动时建立任务表。
func
(
s
*
ChannelMonitorService
)
ListEnabledMonitors
(
ctx
context
.
Context
)
([]
*
ChannelMonitor
,
error
)
{
all
,
err
:=
s
.
repo
.
ListEnabled
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
for
_
,
m
:=
range
all
{
s
.
decryptInPlace
(
m
)
}
return
all
,
nil
}
// cleanupOldHistory 删除 monitorHistoryRetentionDays 天之前的明细历史记录。
// 由 RunDailyMaintenance 调用;SoftDeleteMixin 自动把 DELETE 改为 UPDATE deleted_at。
func
(
s
*
ChannelMonitorService
)
cleanupOldHistory
(
ctx
context
.
Context
)
error
{
before
:=
time
.
Now
()
.
UTC
()
.
AddDate
(
0
,
0
,
-
monitorHistoryRetentionDays
)
deleted
,
err
:=
s
.
repo
.
DeleteHistoryBefore
(
ctx
,
before
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete history before %s: %w"
,
before
.
Format
(
time
.
RFC3339
),
err
)
}
if
deleted
>
0
{
slog
.
Info
(
"channel_monitor: history cleanup"
,
"deleted_rows"
,
deleted
,
"before"
,
before
.
Format
(
time
.
RFC3339
))
}
return
nil
}
// RunDailyMaintenance 每日维护任务:聚合昨天之前未聚合的明细,软删过期明细和聚合。
// 由 OpsCleanupService 的 cron 调度触发(共享 schedule 和 leader lock)。
//
// 幂等性:
// - watermark 保证已聚合的日期不会重复处理;
// - UpsertDailyRollupsFor 内部使用 ON CONFLICT DO UPDATE,同一日重复跑结果一致。
//
// 每一步失败都只记 slog.Warn,整体函数始终返回 nil 让后续步骤能继续跑
// (与 OpsCleanupService.runCleanupOnce 风格一致)。
func
(
s
*
ChannelMonitorService
)
RunDailyMaintenance
(
ctx
context
.
Context
)
error
{
now
:=
time
.
Now
()
.
UTC
()
today
:=
now
.
Truncate
(
24
*
time
.
Hour
)
if
err
:=
s
.
runDailyAggregation
(
ctx
,
today
);
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: maintenance step failed"
,
"step"
,
"aggregate"
,
"error"
,
err
)
}
if
err
:=
s
.
cleanupOldHistory
(
ctx
);
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: maintenance step failed"
,
"step"
,
"prune_history"
,
"error"
,
err
)
}
if
err
:=
s
.
cleanupOldRollups
(
ctx
,
today
);
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: maintenance step failed"
,
"step"
,
"prune_rollups"
,
"error"
,
err
)
}
return
nil
}
// runDailyAggregation 从 watermark+1 聚合到昨天(UTC)。
// 首次跑(watermark nil):从 today-monitorRollupRetentionDays 开始回填。
// 每次最多聚合 monitorMaintenanceMaxDaysPerRun 天,避免长事务。
func
(
s
*
ChannelMonitorService
)
runDailyAggregation
(
ctx
context
.
Context
,
today
time
.
Time
)
error
{
watermark
,
err
:=
s
.
repo
.
LoadAggregationWatermark
(
ctx
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"load watermark: %w"
,
err
)
}
start
:=
s
.
resolveAggregationStart
(
watermark
,
today
)
if
!
start
.
Before
(
today
)
{
return
nil
// 没有需要聚合的日期
}
iterations
:=
0
for
d
:=
start
;
d
.
Before
(
today
);
d
=
d
.
Add
(
24
*
time
.
Hour
)
{
if
iterations
>=
monitorMaintenanceMaxDaysPerRun
{
slog
.
Info
(
"channel_monitor: maintenance aggregation capped"
,
"max_days"
,
monitorMaintenanceMaxDaysPerRun
,
"next_resume"
,
d
.
Format
(
"2006-01-02"
))
break
}
affected
,
upErr
:=
s
.
repo
.
UpsertDailyRollupsFor
(
ctx
,
d
)
if
upErr
!=
nil
{
return
fmt
.
Errorf
(
"upsert rollups for %s: %w"
,
d
.
Format
(
"2006-01-02"
),
upErr
)
}
if
err
:=
s
.
repo
.
UpdateAggregationWatermark
(
ctx
,
d
);
err
!=
nil
{
return
fmt
.
Errorf
(
"update watermark to %s: %w"
,
d
.
Format
(
"2006-01-02"
),
err
)
}
slog
.
Info
(
"channel_monitor: rollups upserted"
,
"date"
,
d
.
Format
(
"2006-01-02"
),
"affected_rows"
,
affected
)
iterations
++
}
return
nil
}
// resolveAggregationStart 计算本次聚合起点:
// - watermark == nil:today - monitorRollupRetentionDays(首次回填最多 30 天)
// - watermark != nil:*watermark + 1 day
func
(
s
*
ChannelMonitorService
)
resolveAggregationStart
(
watermark
*
time
.
Time
,
today
time
.
Time
)
time
.
Time
{
if
watermark
==
nil
{
return
today
.
AddDate
(
0
,
0
,
-
monitorRollupRetentionDays
)
}
return
watermark
.
UTC
()
.
Truncate
(
24
*
time
.
Hour
)
.
Add
(
24
*
time
.
Hour
)
}
// cleanupOldRollups 软删 bucket_date < today - monitorRollupRetentionDays 的日聚合行。
func
(
s
*
ChannelMonitorService
)
cleanupOldRollups
(
ctx
context
.
Context
,
today
time
.
Time
)
error
{
cutoff
:=
today
.
AddDate
(
0
,
0
,
-
monitorRollupRetentionDays
)
deleted
,
err
:=
s
.
repo
.
DeleteRollupsBefore
(
ctx
,
cutoff
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"delete rollups before %s: %w"
,
cutoff
.
Format
(
"2006-01-02"
),
err
)
}
if
deleted
>
0
{
slog
.
Info
(
"channel_monitor: rollups cleanup"
,
"deleted_rows"
,
deleted
,
"before"
,
cutoff
.
Format
(
"2006-01-02"
))
}
return
nil
}
// ---------- helpers ----------
// decryptInPlace 把 ChannelMonitor.APIKey 从密文解密为明文。
// 解密失败时把字段清空 + 设置 APIKeyDecryptFailed=true(不返回错误,避免阻断列表渲染)。
// runner / RunCheck 必须读取该标志位并拒绝执行检测。
func
(
s
*
ChannelMonitorService
)
decryptInPlace
(
m
*
ChannelMonitor
)
{
if
m
==
nil
||
m
.
APIKey
==
""
{
return
}
plain
,
err
:=
s
.
encryptor
.
Decrypt
(
m
.
APIKey
)
if
err
!=
nil
{
slog
.
Warn
(
"channel_monitor: decrypt api key failed"
,
"monitor_id"
,
m
.
ID
,
"error"
,
err
)
m
.
APIKey
=
""
m
.
APIKeyDecryptFailed
=
true
return
}
m
.
APIKey
=
plain
}
// applyMonitorUpdate 把 update params 中非 nil 的字段应用到 existing 上。
// APIKey 字段在调用方单独处理(涉及加密)。
//
// 行数稍超过 30:这是逐字段平铺的 dispatcher,每个 if 都是 1-3 行的"非 nil 则覆盖"模式,
// 拆分反而会增加跳转噪音、影响可读性,故保留为单函数。
func
applyMonitorUpdate
(
existing
*
ChannelMonitor
,
p
ChannelMonitorUpdateParams
)
error
{
if
p
.
Name
!=
nil
{
existing
.
Name
=
strings
.
TrimSpace
(
*
p
.
Name
)
}
if
p
.
Provider
!=
nil
{
if
err
:=
validateProvider
(
*
p
.
Provider
);
err
!=
nil
{
return
err
}
existing
.
Provider
=
*
p
.
Provider
}
if
p
.
Endpoint
!=
nil
{
if
err
:=
validateEndpoint
(
*
p
.
Endpoint
);
err
!=
nil
{
return
err
}
existing
.
Endpoint
=
normalizeEndpoint
(
*
p
.
Endpoint
)
}
if
p
.
PrimaryModel
!=
nil
{
existing
.
PrimaryModel
=
strings
.
TrimSpace
(
*
p
.
PrimaryModel
)
}
if
p
.
ExtraModels
!=
nil
{
existing
.
ExtraModels
=
normalizeModels
(
*
p
.
ExtraModels
)
}
if
p
.
GroupName
!=
nil
{
existing
.
GroupName
=
strings
.
TrimSpace
(
*
p
.
GroupName
)
}
if
p
.
Enabled
!=
nil
{
existing
.
Enabled
=
*
p
.
Enabled
}
if
p
.
IntervalSeconds
!=
nil
{
if
err
:=
validateInterval
(
*
p
.
IntervalSeconds
);
err
!=
nil
{
return
err
}
existing
.
IntervalSeconds
=
*
p
.
IntervalSeconds
}
return
applyMonitorAdvancedUpdate
(
existing
,
p
)
}
// applyMonitorAdvancedUpdate 处理自定义请求快照相关字段,从 applyMonitorUpdate 拆出避免过长。
func
applyMonitorAdvancedUpdate
(
existing
*
ChannelMonitor
,
p
ChannelMonitorUpdateParams
)
error
{
if
p
.
ClearTemplate
{
existing
.
TemplateID
=
nil
}
else
if
p
.
TemplateID
!=
nil
{
id
:=
*
p
.
TemplateID
existing
.
TemplateID
=
&
id
}
if
p
.
ExtraHeaders
!=
nil
{
if
err
:=
validateExtraHeaders
(
*
p
.
ExtraHeaders
);
err
!=
nil
{
return
err
}
existing
.
ExtraHeaders
=
emptyHeadersIfNil
(
*
p
.
ExtraHeaders
)
}
// BodyOverrideMode / BodyOverride 联合校验,和模板一致。
newMode
:=
existing
.
BodyOverrideMode
newBody
:=
existing
.
BodyOverride
if
p
.
BodyOverrideMode
!=
nil
{
newMode
=
*
p
.
BodyOverrideMode
}
if
p
.
BodyOverride
!=
nil
{
newBody
=
*
p
.
BodyOverride
}
if
p
.
BodyOverrideMode
!=
nil
||
p
.
BodyOverride
!=
nil
{
if
err
:=
validateBodyModeParams
(
newMode
,
newBody
);
err
!=
nil
{
return
err
}
existing
.
BodyOverrideMode
=
defaultBodyMode
(
newMode
)
existing
.
BodyOverride
=
newBody
}
return
nil
}
backend/internal/service/channel_monitor_ssrf.go
0 → 100644
View file @
ac114738
package
service
import
(
"context"
"net"
"strings"
)
// SSRF 防护 helper:
// - validateEndpoint 在 admin 提交时阻止 http/loopback/私网/云元数据 URL
// - safeDialContext 在 socket 层再次校验真实 IP,防止 DNS rebinding
//
// 已知 cloud metadata hostname 拒绝列表(小写比较)。
var
monitorBlockedHostnames
=
map
[
string
]
struct
{}{
"localhost"
:
{},
"localhost.localdomain"
:
{},
"metadata"
:
{},
"metadata.google.internal"
:
{},
"metadata.goog"
:
{},
"instance-data"
:
{},
"instance-data.ec2.internal"
:
{},
}
// CIDR 列表:包含所有需要拒绝的 IPv4/IPv6 段。
// 解析时只 panic 一次(启动时确认),生产路径只做 Contains。
var
monitorBlockedCIDRs
=
mustParseCIDRs
([]
string
{
"127.0.0.0/8"
,
// IPv4 loopback
"10.0.0.0/8"
,
// RFC1918
"172.16.0.0/12"
,
// RFC1918
"192.168.0.0/16"
,
// RFC1918
"169.254.0.0/16"
,
// link-local(含云元数据 169.254.169.254)
"100.64.0.0/10"
,
// CGNAT
"0.0.0.0/8"
,
// "this network"
"::1/128"
,
// IPv6 loopback
"fc00::/7"
,
// IPv6 ULA
"fe80::/10"
,
// IPv6 link-local
"::/128"
,
// IPv6 unspecified
})
// monitorDialer 共享 Dialer,与 net/http 默认值对齐。
var
monitorDialer
=
&
net
.
Dialer
{
Timeout
:
monitorDialTimeout
,
KeepAlive
:
monitorDialKeepAlive
,
}
// mustParseCIDRs converts CIDR strings into networks, preserving input order.
// It is meant to run during package initialisation: any entry that fails to
// parse makes it panic with a message naming the offending string.
func mustParseCIDRs(cidrs []string) []*net.IPNet {
	nets := make([]*net.IPNet, len(cidrs))
	for i := range cidrs {
		_, ipNet, err := net.ParseCIDR(cidrs[i])
		if err != nil {
			panic("channel_monitor_ssrf: invalid CIDR " + cidrs[i] + ": " + err.Error())
		}
		nets[i] = ipNet
	}
	return nets
}
// isBlockedHostname 判断 hostname 是否命中黑名单。
func
isBlockedHostname
(
hostname
string
)
bool
{
if
hostname
==
""
{
return
true
}
_
,
blocked
:=
monitorBlockedHostnames
[
strings
.
ToLower
(
hostname
)]
return
blocked
}
// isPrivateIP 判断 IP 是否落在禁止段(loopback/RFC1918/link-local/ULA 等)。
func
isPrivateIP
(
ip
net
.
IP
)
bool
{
if
ip
==
nil
{
return
true
}
if
ip
.
IsUnspecified
()
||
ip
.
IsLoopback
()
||
ip
.
IsLinkLocalUnicast
()
||
ip
.
IsLinkLocalMulticast
()
||
ip
.
IsInterfaceLocalMulticast
()
{
return
true
}
for
_
,
n
:=
range
monitorBlockedCIDRs
{
if
n
.
Contains
(
ip
)
{
return
true
}
}
return
false
}
// isPrivateOrLoopbackHost 解析 hostname 的所有 A/AAAA 记录,
// 任一 IP 落在私网/loopback 段即认为不安全。
//
// hostname 是 IP 字面量时也走同一路径。
func
isPrivateOrLoopbackHost
(
ctx
context
.
Context
,
hostname
string
)
(
bool
,
error
)
{
if
isBlockedHostname
(
hostname
)
{
return
true
,
nil
}
// IP 字面量直接判断。
if
ip
:=
net
.
ParseIP
(
hostname
);
ip
!=
nil
{
return
isPrivateIP
(
ip
),
nil
}
resolver
:=
net
.
DefaultResolver
addrs
,
err
:=
resolver
.
LookupIPAddr
(
ctx
,
hostname
)
if
err
!=
nil
{
return
false
,
err
}
if
len
(
addrs
)
==
0
{
return
true
,
nil
}
for
_
,
a
:=
range
addrs
{
if
isPrivateIP
(
a
.
IP
)
{
return
true
,
nil
}
}
return
false
,
nil
}
// safeDialContext 在真实 dial 前再次校验目标 IP,防止 DNS rebinding。
// 解析 hostname 后逐个 IP 尝试连接,命中私网即拒绝(即便 validateEndpoint 时返回的是公网 IP)。
func
safeDialContext
(
ctx
context
.
Context
,
network
,
address
string
)
(
net
.
Conn
,
error
)
{
host
,
port
,
err
:=
net
.
SplitHostPort
(
address
)
if
err
!=
nil
{
return
nil
,
err
}
// 字面量 IP 走快速路径。
if
ip
:=
net
.
ParseIP
(
host
);
ip
!=
nil
{
if
isPrivateIP
(
ip
)
{
return
nil
,
&
net
.
AddrError
{
Err
:
"blocked by SSRF policy"
,
Addr
:
address
}
}
return
monitorDialer
.
DialContext
(
ctx
,
network
,
address
)
}
if
isBlockedHostname
(
host
)
{
return
nil
,
&
net
.
AddrError
{
Err
:
"blocked by SSRF policy"
,
Addr
:
address
}
}
addrs
,
err
:=
net
.
DefaultResolver
.
LookupIPAddr
(
ctx
,
host
)
if
err
!=
nil
{
return
nil
,
err
}
if
len
(
addrs
)
==
0
{
return
nil
,
&
net
.
AddrError
{
Err
:
"no addresses for host"
,
Addr
:
host
}
}
var
lastErr
error
for
_
,
a
:=
range
addrs
{
if
isPrivateIP
(
a
.
IP
)
{
lastErr
=
&
net
.
AddrError
{
Err
:
"blocked by SSRF policy"
,
Addr
:
a
.
IP
.
String
()}
continue
}
conn
,
err
:=
monitorDialer
.
DialContext
(
ctx
,
network
,
net
.
JoinHostPort
(
a
.
IP
.
String
(),
port
))
if
err
==
nil
{
return
conn
,
nil
}
lastErr
=
err
}
if
lastErr
==
nil
{
lastErr
=
&
net
.
AddrError
{
Err
:
"no usable addresses"
,
Addr
:
host
}
}
return
nil
,
lastErr
}
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment