Unverified Commit ac114738 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1850 from touwaeriol/feat/channel-insights

feat(monitor): channel monitor with available channels & feature flags
parents 0a80ec80 09fd83ab
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。
// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率
// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。
// 超过保留期由每日维护任务分批物理删(不用软删除,理由同 channel_monitor_history)。
type ChannelMonitorDailyRollup struct {
ent.Schema
}
func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "channel_monitor_daily_rollups"},
}
}
func (ChannelMonitorDailyRollup) Fields() []ent.Field {
return []ent.Field{
field.Int64("monitor_id"),
field.String("model").
NotEmpty().
MaxLen(200),
field.Time("bucket_date").
SchemaType(map[string]string{dialect.Postgres: "date"}),
field.Int("total_checks").Default(0),
field.Int("ok_count").Default(0),
field.Int("operational_count").Default(0),
field.Int("degraded_count").Default(0),
field.Int("failed_count").Default(0),
field.Int("error_count").Default(0),
field.Int64("sum_latency_ms").Default(0),
field.Int("count_latency").Default(0),
field.Int64("sum_ping_latency_ms").Default(0),
field.Int("count_ping_latency").Default(0),
field.Time("computed_at").Default(time.Now).UpdateDefault(time.Now),
}
}
func (ChannelMonitorDailyRollup) Edges() []ent.Edge {
return []ent.Edge{
edge.From("monitor", ChannelMonitor.Type).
Ref("daily_rollups").
Field("monitor_id").
Unique().
Required(),
}
}
func (ChannelMonitorDailyRollup) Indexes() []ent.Index {
return []ent.Index{
index.Fields("monitor_id", "model", "bucket_date").Unique(),
index.Fields("bucket_date"),
}
}
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity.
// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天由每日维护任务
// 先聚合到 channel_monitor_daily_rollups,再分批物理删(不用软删除:日志类表无恢复
// 需求,软删会让行和索引只增不减,徒增磁盘和查询开销)。
type ChannelMonitorHistory struct {
ent.Schema
}
func (ChannelMonitorHistory) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "channel_monitor_histories"},
}
}
func (ChannelMonitorHistory) Fields() []ent.Field {
return []ent.Field{
field.Int64("monitor_id"),
field.String("model").
NotEmpty().
MaxLen(200),
field.Enum("status").
Values("operational", "degraded", "failed", "error"),
field.Int("latency_ms").
Optional().
Nillable(),
field.Int("ping_latency_ms").
Optional().
Nillable(),
field.String("message").
Optional().
Default("").
MaxLen(500),
field.Time("checked_at").
Default(time.Now),
}
}
func (ChannelMonitorHistory) Edges() []ent.Edge {
return []ent.Edge{
edge.From("monitor", ChannelMonitor.Type).
Ref("history").
Field("monitor_id").
Unique().
Required(),
}
}
func (ChannelMonitorHistory) Indexes() []ent.Index {
return []ent.Index{
index.Fields("monitor_id", "model", "checked_at"),
index.Fields("checked_at"),
}
}
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// ChannelMonitorRequestTemplate 请求模板:一组可复用的 headers + 可选 body 覆盖配置。
//
// 语义为快照:模板被"应用"到监控时,extra_headers / body_override_mode / body_override
// 会被**拷贝**到 channel_monitors 同名字段;后续模板变动不会自动影响已应用的监控——
// 必须用户主动在模板编辑 Dialog 里点「应用到关联监控」才会覆盖快照。
// 这样模板改错不会瞬间打挂所有已经跑起来的监控。
type ChannelMonitorRequestTemplate struct {
ent.Schema
}
func (ChannelMonitorRequestTemplate) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "channel_monitor_request_templates"},
}
}
func (ChannelMonitorRequestTemplate) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
return []ent.Field{
field.String("name").
NotEmpty().
MaxLen(100),
field.Enum("provider").
Values("openai", "anthropic", "gemini"),
field.String("description").
Optional().
Default("").
MaxLen(500),
// extra_headers: 用户自定义 HTTP 头(如 User-Agent 伪装)。
// 运行时 merge 进 adapter 默认 headers,用户值优先;
// hop-by-hop 黑名单(Host/Content-Length/...)由 checker 过滤。
field.JSON("extra_headers", map[string]string{}).
Default(map[string]string{}),
// body_override_mode: 'off' | 'merge' | 'replace'
// off - 用 adapter 默认 body(忽略 body_override)
// merge - adapter 默认 body 与 body_override 浅合并(body_override 优先,
// model/messages/contents 等关键字段在 checker 里走黑名单跳过)
// replace - 直接用 body_override 作为完整 body;此时跳过 challenge 校验,
// 改为 HTTP 2xx + 响应文本非空即视为可用
field.String("body_override_mode").
Default("off").
MaxLen(10),
// body_override: JSON 对象,根据 body_override_mode 使用。
// 用 map[string]any 以便前端传任意结构(含嵌套)。
field.JSON("body_override", map[string]any{}).
Optional(),
}
}
func (ChannelMonitorRequestTemplate) Edges() []ent.Edge {
return []ent.Edge{
edge.From("monitors", ChannelMonitor.Type).
Ref("request_template"),
}
}
func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
return []ent.Index{
// 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
index.Fields("provider", "name").Unique(),
}
}
......@@ -28,6 +28,14 @@ type Tx struct {
AuthIdentity *AuthIdentityClient
// AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
AuthIdentityChannel *AuthIdentityChannelClient
// ChannelMonitor is the client for interacting with the ChannelMonitor builders.
ChannelMonitor *ChannelMonitorClient
// ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
// ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
ChannelMonitorHistory *ChannelMonitorHistoryClient
// ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
......@@ -212,6 +220,10 @@ func (tx *Tx) init() {
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.AuthIdentity = NewAuthIdentityClient(tx.config)
tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
......
......@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
......
......@@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = service.BillingModelSourceChannelMapped
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
......
......@@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
BillingModelSource: service.BillingModelSourceChannelMapped,
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
......@@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
},
}
// handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底
// 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理),
// 因此这里构造 fixture 时直接传入归一化后的值。
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
......@@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) {
// handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
resp := channelToResponse(ch)
require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责")
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// monitorMaxPageSize 列表分页上限。
monitorMaxPageSize = 100
// monitorAPIKeyMaskPrefix 脱敏时保留的明文前缀长度。
monitorAPIKeyMaskPrefix = 4
// monitorAPIKeyMaskSuffix 脱敏后追加的占位字符串。
monitorAPIKeyMaskSuffix = "***"
)
// ChannelMonitorHandler 渠道监控管理后台 handler。
type ChannelMonitorHandler struct {
monitorService *service.ChannelMonitorService
}
// NewChannelMonitorHandler 创建 handler。
func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorHandler {
return &ChannelMonitorHandler{monitorService: monitorService}
}
// --- Request / Response ---
type channelMonitorCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
Endpoint string `json:"endpoint" binding:"required,max=500"`
APIKey string `json:"api_key" binding:"required,max=2000"`
PrimaryModel string `json:"primary_model" binding:"required,max=200"`
ExtraModels []string `json:"extra_models"`
GroupName string `json:"group_name" binding:"max=100"`
Enabled *bool `json:"enabled"`
IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"`
TemplateID *int64 `json:"template_id"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
ExtraModels *[]string `json:"extra_models"`
GroupName *string `json:"group_name" binding:"omitempty,max=100"`
Enabled *bool `json:"enabled"`
IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
TemplateID *int64 `json:"template_id"`
ClearTemplate bool `json:"clear_template"` // true 时把 template_id 置空,忽略 TemplateID
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride *map[string]any `json:"body_override"`
}
type channelMonitorResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Endpoint string `json:"endpoint"`
APIKeyMasked string `json:"api_key_masked"`
APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
PrimaryModel string `json:"primary_model"`
ExtraModels []string `json:"extra_models"`
GroupName string `json:"group_name"`
Enabled bool `json:"enabled"`
IntervalSeconds int `json:"interval_seconds"`
LastCheckedAt *string `json:"last_checked_at"`
CreatedBy int64 `json:"created_by"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
PrimaryStatus string `json:"primary_status"`
PrimaryLatencyMs *int `json:"primary_latency_ms"`
Availability7d float64 `json:"availability_7d"`
ExtraModelsStatus []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"`
// 请求自定义快照:前端编辑 / 展示「高级设置」用
TemplateID *int64 `json:"template_id"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorCheckResultResponse struct {
Model string `json:"model"`
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
PingLatencyMs *int `json:"ping_latency_ms"`
Message string `json:"message"`
CheckedAt string `json:"checked_at"`
}
type channelMonitorHistoryItemResponse struct {
ID int64 `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
PingLatencyMs *int `json:"ping_latency_ms"`
Message string `json:"message"`
CheckedAt string `json:"checked_at"`
}
// maskAPIKey 对 API Key 明文做脱敏:前 4 字符 + "***",长度 ≤ 4 时只显示 "***"。
func maskAPIKey(plain string) string {
if len(plain) <= monitorAPIKeyMaskPrefix {
return monitorAPIKeyMaskSuffix
}
return plain[:monitorAPIKeyMaskPrefix] + monitorAPIKeyMaskSuffix
}
func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse {
if m == nil {
return nil
}
extras := m.ExtraModels
if extras == nil {
extras = []string{}
}
headers := m.ExtraHeaders
if headers == nil {
headers = map[string]string{}
}
resp := &channelMonitorResponse{
ID: m.ID,
Name: m.Name,
Provider: m.Provider,
Endpoint: m.Endpoint,
APIKeyMasked: maskAPIKey(m.APIKey),
APIKeyDecryptFailed: m.APIKeyDecryptFailed,
PrimaryModel: m.PrimaryModel,
ExtraModels: extras,
GroupName: m.GroupName,
Enabled: m.Enabled,
IntervalSeconds: m.IntervalSeconds,
CreatedBy: m.CreatedBy,
CreatedAt: m.CreatedAt.UTC().Format(time.RFC3339),
UpdatedAt: m.UpdatedAt.UTC().Format(time.RFC3339),
TemplateID: m.TemplateID,
ExtraHeaders: headers,
BodyOverrideMode: m.BodyOverrideMode,
BodyOverride: m.BodyOverride,
// PrimaryStatus / PrimaryLatencyMs / Availability7d 由 List handler 在批量聚合后填充。
}
if m.LastCheckedAt != nil {
s := m.LastCheckedAt.UTC().Format(time.RFC3339)
resp.LastCheckedAt = &s
}
return resp
}
func checkResultToResponse(r *service.CheckResult) channelMonitorCheckResultResponse {
return channelMonitorCheckResultResponse{
Model: r.Model,
Status: r.Status,
LatencyMs: r.LatencyMs,
PingLatencyMs: r.PingLatencyMs,
Message: r.Message,
CheckedAt: r.CheckedAt.UTC().Format(time.RFC3339),
}
}
func historyEntryToResponse(e *service.ChannelMonitorHistoryEntry) channelMonitorHistoryItemResponse {
return channelMonitorHistoryItemResponse{
ID: e.ID,
Model: e.Model,
Status: e.Status,
LatencyMs: e.LatencyMs,
PingLatencyMs: e.PingLatencyMs,
Message: e.Message,
CheckedAt: e.CheckedAt.UTC().Format(time.RFC3339),
}
}
// ParseChannelMonitorID 提取并校验路径参数 :id(admin 与 user handler 共享)。
// 校验失败时已写入 4xx 响应,调用方只需 return。
func ParseChannelMonitorID(c *gin.Context) (int64, bool) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_MONITOR_ID", "invalid monitor id"))
return 0, false
}
return id, true
}
// parseListEnabled 解析 enabled query 参数:true/false 转为 *bool,空或非法则返回 nil。
func parseListEnabled(raw string) *bool {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "true", "1", "yes":
v := true
return &v
case "false", "0", "no":
v := false
return &v
default:
return nil
}
}
// --- Handlers ---
// List GET /api/v1/admin/channel-monitors
func (h *ChannelMonitorHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
if pageSize > monitorMaxPageSize {
pageSize = monitorMaxPageSize
}
params := service.ChannelMonitorListParams{
Page: page,
PageSize: pageSize,
Provider: strings.TrimSpace(c.Query("provider")),
Enabled: parseListEnabled(c.Query("enabled")),
Search: strings.TrimSpace(c.Query("search")),
}
items, total, err := h.monitorService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
summaries := h.batchSummaryFor(c, items)
out := make([]*channelMonitorResponse, 0, len(items))
for _, m := range items {
out = append(out, buildListItemResponse(m, summaries[m.ID]))
}
response.Paginated(c, out, total, page, pageSize)
}
// batchSummaryFor 批量聚合 latest + 7d 可用率,避免每行 2 次 SQL(消除 N+1)。
func (h *ChannelMonitorHandler) batchSummaryFor(c *gin.Context, items []*service.ChannelMonitor) map[int64]service.MonitorStatusSummary {
ids := make([]int64, 0, len(items))
primaryByID := make(map[int64]string, len(items))
extrasByID := make(map[int64][]string, len(items))
for _, m := range items {
ids = append(ids, m.ID)
primaryByID[m.ID] = m.PrimaryModel
extrasByID[m.ID] = m.ExtraModels
}
return h.monitorService.BatchMonitorStatusSummary(c.Request.Context(), ids, primaryByID, extrasByID)
}
// buildListItemResponse 把 monitor + summary 装成 admin list 的响应行。
func buildListItemResponse(m *service.ChannelMonitor, summary service.MonitorStatusSummary) *channelMonitorResponse {
resp := channelMonitorToResponse(m)
resp.PrimaryStatus = summary.PrimaryStatus
resp.PrimaryLatencyMs = summary.PrimaryLatencyMs
resp.Availability7d = summary.Availability7d
resp.ExtraModelsStatus = make([]dto.ChannelMonitorExtraModelStatus, 0, len(summary.ExtraModels))
for _, e := range summary.ExtraModels {
resp.ExtraModelsStatus = append(resp.ExtraModelsStatus, dto.ChannelMonitorExtraModelStatus{
Model: e.Model,
Status: e.Status,
LatencyMs: e.LatencyMs,
})
}
return resp
}
// Get GET /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Get(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
m, err := h.monitorService.Get(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelMonitorToResponse(m))
}
// Create POST /api/v1/admin/channel-monitors
func (h *ChannelMonitorHandler) Create(c *gin.Context) {
var req channelMonitorCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
subject, _ := middleware2.GetAuthSubjectFromContext(c)
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
Name: req.Name,
Provider: req.Provider,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
ExtraModels: req.ExtraModels,
GroupName: req.GroupName,
Enabled: enabled,
IntervalSeconds: req.IntervalSeconds,
CreatedBy: subject.UserID,
TemplateID: req.TemplateID,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Created(c, channelMonitorToResponse(m))
}
// Update PUT /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Update(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
var req channelMonitorUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
Name: req.Name,
Provider: req.Provider,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
ExtraModels: req.ExtraModels,
GroupName: req.GroupName,
Enabled: req.Enabled,
IntervalSeconds: req.IntervalSeconds,
TemplateID: req.TemplateID,
ClearTemplate: req.ClearTemplate,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, channelMonitorToResponse(m))
}
// Delete DELETE /api/v1/admin/channel-monitors/:id
func (h *ChannelMonitorHandler) Delete(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
if err := h.monitorService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// Run POST /api/v1/admin/channel-monitors/:id/run
func (h *ChannelMonitorHandler) Run(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
results, err := h.monitorService.RunCheck(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]channelMonitorCheckResultResponse, 0, len(results))
for _, r := range results {
out = append(out, checkResultToResponse(r))
}
response.Success(c, gin.H{"results": out})
}
// History GET /api/v1/admin/channel-monitors/:id/history
func (h *ChannelMonitorHandler) History(c *gin.Context) {
id, ok := ParseChannelMonitorID(c)
if !ok {
return
}
limit := parseHistoryLimit(c.Query("limit"))
model := strings.TrimSpace(c.Query("model"))
entries, err := h.monitorService.ListHistory(c.Request.Context(), id, model, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]channelMonitorHistoryItemResponse, 0, len(entries))
for _, e := range entries {
out = append(out, historyEntryToResponse(e))
}
response.Success(c, gin.H{"items": out})
}
// parseHistoryLimit 解析 history 接口的 limit query。
// 使用 service 包的统一上下限常量,避免在 handler 重复定义同名魔法值。
func parseHistoryLimit(raw string) int {
if strings.TrimSpace(raw) == "" {
return service.MonitorHistoryDefaultLimit
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
return service.MonitorHistoryDefaultLimit
}
if v > service.MonitorHistoryMaxLimit {
return service.MonitorHistoryMaxLimit
}
return v
}
package admin
import (
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。
type ChannelMonitorRequestTemplateHandler struct {
templateService *service.ChannelMonitorRequestTemplateService
}
// NewChannelMonitorRequestTemplateHandler 创建 handler。
func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler {
return &ChannelMonitorRequestTemplateHandler{templateService: templateService}
}
// --- DTO ---
type channelMonitorTemplateCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
Description string `json:"description" binding:"max=500"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride map[string]any `json:"body_override"`
}
type channelMonitorTemplateUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description" binding:"omitempty,max=500"`
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
BodyOverride *map[string]any `json:"body_override"`
}
type channelMonitorTemplateResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Description string `json:"description"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
BodyOverride map[string]any `json:"body_override"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
AssociatedMonitors int64 `json:"associated_monitors"`
}
func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse {
if t == nil {
return nil
}
headers := t.ExtraHeaders
if headers == nil {
headers = map[string]string{}
}
count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID)
return &channelMonitorTemplateResponse{
ID: t.ID,
Name: t.Name,
Provider: t.Provider,
Description: t.Description,
ExtraHeaders: headers,
BodyOverrideMode: t.BodyOverrideMode,
BodyOverride: t.BodyOverride,
CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339),
AssociatedMonitors: count,
}
}
// parseTemplateID 提取并校验 :id。
func parseTemplateID(c *gin.Context) (int64, bool) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id"))
return 0, false
}
return id, true
}
// --- Handlers ---
// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic
func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
Provider: strings.TrimSpace(c.Query("provider")),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*channelMonitorTemplateResponse, 0, len(items))
for _, t := range items {
out = append(out, h.toResponse(c, t))
}
response.Success(c, gin.H{"items": out})
}
// Get GET /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
t, err := h.templateService.Get(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.toResponse(c, t))
}
// Create POST /api/v1/admin/channel-monitor-templates
func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
var req channelMonitorTemplateCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
Name: req.Name,
Provider: req.Provider,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Created(c, h.toResponse(c, t))
}
// Update PUT /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
var req channelMonitorTemplateUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
Name: req.Name,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
BodyOverride: req.BodyOverride,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.toResponse(c, t))
}
// Delete DELETE /api/v1/admin/channel-monitor-templates/:id
func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
if err := h.templateService.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
type channelMonitorTemplateApplyRequest struct {
// MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。
// 仅当对应监控当前 template_id == :id 时才会真的被覆盖。
MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"`
}
// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。
func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
var req channelMonitorTemplateApplyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"affected": affected})
}
type associatedMonitorBriefResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
}
// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors
// 列出关联监控(picker 弹窗用)。
func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) {
id, ok := parseTemplateID(c)
if !ok {
return
}
items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]associatedMonitorBriefResponse, 0, len(items))
for _, m := range items {
out = append(out, associatedMonitorBriefResponse{
ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
})
}
response.Success(c, gin.H{"items": out})
}
......@@ -236,6 +236,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
......@@ -427,6 +432,13 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
// Channel Monitor feature switch
ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
// Available Channels feature switch (user-facing)
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
}
// UpdateSettings 更新系统设置
......@@ -1222,6 +1234,24 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AccountQuotaNotifyEmails
}(),
ChannelMonitorEnabled: func() bool {
if req.ChannelMonitorEnabled != nil {
return *req.ChannelMonitorEnabled
}
return previousSettings.ChannelMonitorEnabled
}(),
ChannelMonitorDefaultIntervalSeconds: func() int {
if req.ChannelMonitorDefaultIntervalSeconds != nil {
return *req.ChannelMonitorDefaultIntervalSeconds
}
return previousSettings.ChannelMonitorDefaultIntervalSeconds
}(),
AvailableChannelsEnabled: func() bool {
if req.AvailableChannelsEnabled != nil {
return *req.AvailableChannelsEnabled
}
return previousSettings.AvailableChannelsEnabled
}(),
}
authSourceDefaults := &service.AuthSourceDefaultSettings{
......@@ -1453,6 +1483,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
......@@ -1809,6 +1844,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled {
changed = append(changed, "channel_monitor_enabled")
}
if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds {
changed = append(changed, "channel_monitor_default_interval_seconds")
}
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
changed = append(changed, "available_channels_enabled")
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
......
package handler
import (
"sort"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AvailableChannelHandler 处理用户侧「可用渠道」查询。
//
// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做三层过滤:
// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道;
// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些;
// 3. 平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型,
// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问
// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏;
// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels
// / 内部 ID / Status 等管理字段)。
type AvailableChannelHandler struct {
channelService *service.ChannelService
apiKeyService *service.APIKeyService
settingService *service.SettingService
}
// NewAvailableChannelHandler 创建用户侧可用渠道 handler。
func NewAvailableChannelHandler(
channelService *service.ChannelService,
apiKeyService *service.APIKeyService,
settingService *service.SettingService,
) *AvailableChannelHandler {
return &AvailableChannelHandler{
channelService: channelService,
apiKeyService: apiKeyService,
settingService: settingService,
}
}
// featureEnabled 返回 available-channels 开关是否启用。默认关闭(opt-in)。
func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool {
if h.settingService == nil {
return false
}
return h.settingService.GetAvailableChannelsRuntime(c.Request.Context()).Enabled
}
// userAvailableGroup 用户可见的分组概要(白名单字段)。
//
// 前端据此区分专属 vs 公开分组(IsExclusive)、订阅 vs 标准分组(SubscriptionType,
// 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走
// /groups/rates,和 API 密钥页面保持一致。
type userAvailableGroup struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
}
// userSupportedModelPricing 用户可见的定价字段白名单。
type userSupportedModelPricing struct {
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []userPricingIntervalDTO `json:"intervals"`
}
// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。
type userPricingIntervalDTO struct {
MinTokens int `json:"min_tokens"`
MaxTokens *int `json:"max_tokens"`
TierLabel string `json:"tier_label,omitempty"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
PerRequestPrice *float64 `json:"per_request_price"`
}
// userSupportedModel 用户可见的支持模型条目。
type userSupportedModel struct {
Name string `json:"name"`
Platform string `json:"platform"`
Pricing *userSupportedModelPricing `json:"pricing"`
}
// userChannelPlatformSection 单渠道内某个平台的子视图:用户可见的分组 + 该平台
// 支持的模型。按 platform 聚合后让前端可以把渠道名作为 row-group 一次渲染,
// 后面的平台行按 sections 顺序铺开。
type userChannelPlatformSection struct {
Platform string `json:"platform"`
Groups []userAvailableGroup `json:"groups"`
SupportedModels []userSupportedModel `json:"supported_models"`
}
// userAvailableChannel 用户可见的渠道条目(白名单字段)。
//
// 每个渠道聚合为一条记录,内嵌 platforms 子数组:每个 section 对应一个平台,
// 包含该平台的 groups 和 supported_models。
type userAvailableChannel struct {
Name string `json:"name"`
Description string `json:"description"`
Platforms []userChannelPlatformSection `json:"platforms"`
}
// List 列出当前用户可见的「可用渠道」。
// GET /api/v1/channels/available
func (h *AvailableChannelHandler) List(c *gin.Context) {
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
// Feature 未启用时返回空数组(不暴露渠道信息)。检查放在认证之后,
// 保持与未开关前的 401 行为一致:未登录先 401,登录后再按开关决定。
if !h.featureEnabled(c) {
response.Success(c, []userAvailableChannel{})
return
}
userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
allowedGroupIDs := make(map[int64]struct{}, len(userGroups))
for i := range userGroups {
allowedGroupIDs[userGroups[i].ID] = struct{}{}
}
channels, err := h.channelService.ListAvailable(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]userAvailableChannel, 0, len(channels))
for _, ch := range channels {
if ch.Status != service.StatusActive {
continue
}
visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
if len(visibleGroups) == 0 {
continue
}
sections := buildPlatformSections(ch, visibleGroups)
if len(sections) == 0 {
continue
}
out = append(out, userAvailableChannel{
Name: ch.Name,
Description: ch.Description,
Platforms: sections,
})
}
response.Success(c, out)
}
// buildPlatformSections 把一个渠道按 visibleGroups 的平台集合拆成有序的 section 列表:
// 每个 section 对应一个平台,只包含该平台的 groups 和 supported_models。
// 输出按 platform 字母序稳定排序,便于前端等效比较与回归测试。
func buildPlatformSections(
ch service.AvailableChannel,
visibleGroups []userAvailableGroup,
) []userChannelPlatformSection {
groupsByPlatform := make(map[string][]userAvailableGroup, 4)
for _, g := range visibleGroups {
if g.Platform == "" {
continue
}
groupsByPlatform[g.Platform] = append(groupsByPlatform[g.Platform], g)
}
if len(groupsByPlatform) == 0 {
return nil
}
platforms := make([]string, 0, len(groupsByPlatform))
for p := range groupsByPlatform {
platforms = append(platforms, p)
}
sort.Strings(platforms)
sections := make([]userChannelPlatformSection, 0, len(platforms))
for _, platform := range platforms {
platformSet := map[string]struct{}{platform: {}}
sections = append(sections, userChannelPlatformSection{
Platform: platform,
Groups: groupsByPlatform[platform],
SupportedModels: toUserSupportedModels(ch.SupportedModels, platformSet),
})
}
return sections
}
// filterUserVisibleGroups 仅保留用户可访问的分组。
func filterUserVisibleGroups(
groups []service.AvailableGroupRef,
allowed map[int64]struct{},
) []userAvailableGroup {
visible := make([]userAvailableGroup, 0, len(groups))
for _, g := range groups {
if _, ok := allowed[g.ID]; !ok {
continue
}
visible = append(visible, userAvailableGroup{
ID: g.ID,
Name: g.Name,
Platform: g.Platform,
SubscriptionType: g.SubscriptionType,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
})
}
return visible
}
// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。
// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。
// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。
func toUserSupportedModels(
src []service.SupportedModel,
allowedPlatforms map[string]struct{},
) []userSupportedModel {
out := make([]userSupportedModel, 0, len(src))
for i := range src {
m := src[i]
if allowedPlatforms != nil {
if _, ok := allowedPlatforms[m.Platform]; !ok {
continue
}
}
out = append(out, userSupportedModel{
Name: m.Name,
Platform: m.Platform,
Pricing: toUserPricing(m.Pricing),
})
}
return out
}
// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。
func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
if p == nil {
return nil
}
intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals))
for _, iv := range p.Intervals {
intervals = append(intervals, userPricingIntervalDTO{
MinTokens: iv.MinTokens,
MaxTokens: iv.MaxTokens,
TierLabel: iv.TierLabel,
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
})
}
billingMode := string(p.BillingMode)
if billingMode == "" {
billingMode = string(service.BillingModeToken)
}
return &userSupportedModelPricing{
BillingMode: billingMode,
InputPrice: p.InputPrice,
OutputPrice: p.OutputPrice,
CacheWritePrice: p.CacheWritePrice,
CacheReadPrice: p.CacheReadPrice,
ImageOutputPrice: p.ImageOutputPrice,
PerRequestPrice: p.PerRequestPrice,
Intervals: intervals,
}
}
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
// 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。
gin.SetMode(gin.TestMode)
h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil)
h.List(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
}
func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
// 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。
groups := []service.AvailableGroupRef{
{ID: 1, Name: "g1", Platform: "anthropic"},
{ID: 2, Name: "g2", Platform: "anthropic"},
{ID: 3, Name: "g3", Platform: "openai"},
}
allowed := map[int64]struct{}{1: {}, 3: {}}
visible := filterUserVisibleGroups(groups, allowed)
require.Len(t, visible, 2)
ids := []int64{visible[0].ID, visible[1].ID}
require.ElementsMatch(t, []int64{1, 3}, ids)
}
func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
// 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。
src := []service.SupportedModel{
{Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil},
{Name: "gpt-4o", Platform: "openai", Pricing: nil},
}
allowed := map[string]struct{}{"anthropic": {}}
out := toUserSupportedModels(src, allowed)
require.Len(t, out, 1)
require.Equal(t, "claude-sonnet-4-6", out[0].Name)
}
func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) {
// 显式传 nil allowedPlatforms 表示不做过滤。
src := []service.SupportedModel{
{Name: "a", Platform: "anthropic"},
{Name: "b", Platform: "openai"},
}
require.Len(t, toUserSupportedModels(src, nil), 2)
}
func TestUserAvailableChannel_FieldWhitelist(t *testing.T) {
// 通过序列化 userAvailableChannel 结构体验证响应形状:
// 只有 name / description / platforms;不含管理端字段。
row := userAvailableChannel{
Name: "ch",
Description: "d",
Platforms: []userChannelPlatformSection{
{
Platform: "anthropic",
Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}},
SupportedModels: []userSupportedModel{},
},
},
}
raw, err := json.Marshal(row)
require.NoError(t, err)
var decoded map[string]any
require.NoError(t, json.Unmarshal(raw, &decoded))
for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} {
_, exists := decoded[key]
require.Falsef(t, exists, "user DTO must not expose %q", key)
}
for _, key := range []string{"name", "description", "platforms"} {
_, exists := decoded[key]
require.Truef(t, exists, "user DTO must expose %q", key)
}
// 验证 section 的字段(platform / groups / supported_models)。
rawSection, err := json.Marshal(row.Platforms[0])
require.NoError(t, err)
var sectionDecoded map[string]any
require.NoError(t, json.Unmarshal(rawSection, &sectionDecoded))
for _, key := range []string{"platform", "groups", "supported_models"} {
_, exists := sectionDecoded[key]
require.Truef(t, exists, "platform section must expose %q", key)
}
// Group DTO 暴露区分专属/公开、订阅类型、默认倍率所需的字段,
// 前端据此渲染 GroupBadge 并与 API 密钥页保持一致的视觉。
rawGroup, err := json.Marshal(row.Platforms[0].Groups[0])
require.NoError(t, err)
var groupDecoded map[string]any
require.NoError(t, json.Unmarshal(rawGroup, &groupDecoded))
for _, key := range []string{"id", "name", "platform", "subscription_type", "rate_multiplier", "is_exclusive"} {
_, exists := groupDecoded[key]
require.Truef(t, exists, "group DTO must expose %q", key)
}
// pricing interval 白名单:不应暴露 id / sort_order。
pricing := toUserPricing(&service.ChannelModelPricing{
BillingMode: service.BillingModeToken,
Intervals: []service.PricingInterval{
{ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3},
},
})
require.NotNil(t, pricing)
require.Len(t, pricing.Intervals, 1)
rawIv, err := json.Marshal(pricing.Intervals[0])
require.NoError(t, err)
var ivDecoded map[string]any
require.NoError(t, json.Unmarshal(rawIv, &ivDecoded))
for _, key := range []string{"id", "pricing_id", "sort_order"} {
_, exists := ivDecoded[key]
require.Falsef(t, exists, "user pricing interval must not expose %q", key)
}
}
func TestBuildPlatformSections_GroupsByPlatform(t *testing.T) {
// 一个渠道横跨 anthropic / openai / 空平台:应该生成 2 个 section,
// 按 platform 字母序排序,各自 groups 和 supported_models 只含同平台条目。
ch := service.AvailableChannel{
Name: "ch",
SupportedModels: []service.SupportedModel{
{Name: "claude-sonnet-4-6", Platform: "anthropic"},
{Name: "gpt-4o", Platform: "openai"},
},
}
visible := []userAvailableGroup{
{ID: 1, Name: "g-openai", Platform: "openai"},
{ID: 2, Name: "g-ant", Platform: "anthropic"},
{ID: 3, Name: "g-empty", Platform: ""},
}
sections := buildPlatformSections(ch, visible)
require.Len(t, sections, 2)
require.Equal(t, "anthropic", sections[0].Platform)
require.Equal(t, "openai", sections[1].Platform)
require.Len(t, sections[0].Groups, 1)
require.Equal(t, int64(2), sections[0].Groups[0].ID)
require.Len(t, sections[0].SupportedModels, 1)
require.Equal(t, "claude-sonnet-4-6", sections[0].SupportedModels[0].Name)
}
package handler
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChannelMonitorUserHandler 渠道监控用户只读 handler。
type ChannelMonitorUserHandler struct {
monitorService *service.ChannelMonitorService
settingService *service.SettingService
}
// NewChannelMonitorUserHandler 创建 handler。
// settingService 用于每次请求前读取功能开关;关闭时 List/GetStatus 直接返回空/404。
func NewChannelMonitorUserHandler(
monitorService *service.ChannelMonitorService,
settingService *service.SettingService,
) *ChannelMonitorUserHandler {
return &ChannelMonitorUserHandler{
monitorService: monitorService,
settingService: settingService,
}
}
// featureEnabled 返回当前渠道监控功能是否开启。
// settingService 为 nil(测试场景)视为启用。
func (h *ChannelMonitorUserHandler) featureEnabled(c *gin.Context) bool {
if h.settingService == nil {
return true
}
return h.settingService.GetChannelMonitorRuntime(c.Request.Context()).Enabled
}
// --- Response ---
type channelMonitorUserListItem struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
GroupName string `json:"group_name"`
PrimaryModel string `json:"primary_model"`
PrimaryStatus string `json:"primary_status"`
PrimaryLatencyMs *int `json:"primary_latency_ms"`
PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"`
Availability7d float64 `json:"availability_7d"`
ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"`
Timeline []channelMonitorUserTimelinePoint `json:"timeline"`
}
// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。
// 仅用于用户视图 list 响应,admin 视图不使用。
type channelMonitorUserTimelinePoint struct {
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
PingLatencyMs *int `json:"ping_latency_ms"`
CheckedAt string `json:"checked_at"`
}
type channelMonitorUserDetailResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
GroupName string `json:"group_name"`
Models []channelMonitorUserModelStat `json:"models"`
}
type channelMonitorUserModelStat struct {
Model string `json:"model"`
LatestStatus string `json:"latest_status"`
LatestLatencyMs *int `json:"latest_latency_ms"`
Availability7d float64 `json:"availability_7d"`
Availability15d float64 `json:"availability_15d"`
Availability30d float64 `json:"availability_30d"`
AvgLatency7dMs *int `json:"avg_latency_7d_ms"`
}
func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListItem {
extras := make([]dto.ChannelMonitorExtraModelStatus, 0, len(v.ExtraModels))
for _, e := range v.ExtraModels {
extras = append(extras, dto.ChannelMonitorExtraModelStatus{
Model: e.Model,
Status: e.Status,
LatencyMs: e.LatencyMs,
})
}
timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline))
for _, p := range v.Timeline {
timeline = append(timeline, channelMonitorUserTimelinePoint{
Status: p.Status,
LatencyMs: p.LatencyMs,
PingLatencyMs: p.PingLatencyMs,
CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339),
})
}
return channelMonitorUserListItem{
ID: v.ID,
Name: v.Name,
Provider: v.Provider,
GroupName: v.GroupName,
PrimaryModel: v.PrimaryModel,
PrimaryStatus: v.PrimaryStatus,
PrimaryLatencyMs: v.PrimaryLatencyMs,
PrimaryPingLatencyMs: v.PrimaryPingLatencyMs,
Availability7d: v.Availability7d,
ExtraModels: extras,
Timeline: timeline,
}
}
func userMonitorDetailToResponse(d *service.UserMonitorDetail) *channelMonitorUserDetailResponse {
models := make([]channelMonitorUserModelStat, 0, len(d.Models))
for _, m := range d.Models {
models = append(models, channelMonitorUserModelStat{
Model: m.Model,
LatestStatus: m.LatestStatus,
LatestLatencyMs: m.LatestLatencyMs,
Availability7d: m.Availability7d,
Availability15d: m.Availability15d,
Availability30d: m.Availability30d,
AvgLatency7dMs: m.AvgLatency7dMs,
})
}
return &channelMonitorUserDetailResponse{
ID: d.ID,
Name: d.Name,
Provider: d.Provider,
GroupName: d.GroupName,
Models: models,
}
}
// --- Handlers ---
// List GET /api/v1/channel-monitors
func (h *ChannelMonitorUserHandler) List(c *gin.Context) {
if !h.featureEnabled(c) {
response.Success(c, gin.H{"items": []channelMonitorUserListItem{}})
return
}
views, err := h.monitorService.ListUserView(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
items := make([]channelMonitorUserListItem, 0, len(views))
for _, v := range views {
items = append(items, userMonitorViewToItem(v))
}
response.Success(c, gin.H{"items": items})
}
// GetStatus GET /api/v1/channel-monitors/:id/status
func (h *ChannelMonitorUserHandler) GetStatus(c *gin.Context) {
if !h.featureEnabled(c) {
response.ErrorFrom(c, service.ErrChannelMonitorNotFound)
return
}
// 复用 admin.ParseChannelMonitorID 保持错误码与日志一致。
id, ok := admin.ParseChannelMonitorID(c)
if !ok {
return
}
detail, err := h.monitorService.GetUserDetail(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, userMonitorDetailToResponse(detail))
}
package dto
// ChannelMonitorExtraModelStatus 渠道监控附加模型最近一次状态。
// 同时被 admin handler(List 响应)与 user handler(List 响应)复用,
// 字段必须保持一致以保证前端拿到统一结构。
type ChannelMonitorExtraModelStatus struct {
Model string `json:"model"`
Status string `json:"status"`
LatencyMs *int `json:"latency_ms"`
}
package dto
import (
"reflect"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestPublicSettingsInjectionPayload_SchemaDoesNotDrift guarantees the SSR
// injection struct exposes every JSON field consumed by the frontend.
//
// Why this test exists: before we extracted a named PublicSettingsInjectionPayload
// type, the inline struct was manually kept in sync with dto.PublicSettings and
// drifted — ChannelMonitorEnabled / AvailableChannelsEnabled were missing, which
// made the frontend read `undefined` on refresh and hide the "可用渠道" menu
// until the async /api/v1/settings/public round-trip finished.
//
// This test compares the two JSON-tag sets and fails if injection is missing
// any field that dto.PublicSettings exposes. Adding a new feature flag with
// only a DTO entry will fail this test until the injection struct is updated.
//
// Intentional exclusions (fields present on dto.PublicSettings that SSR does
// not need to inject) are listed in `dtoOnlyFields` below with a reason.
func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) {
injection := jsonTags(reflect.TypeOf(service.PublicSettingsInjectionPayload{}))
dtoKeys := jsonTags(reflect.TypeOf(PublicSettings{}))
// Fields that legitimately live only on the DTO. Keep tiny; document each.
dtoOnlyFields := map[string]string{
// sora_client_enabled is an upstream-only field the fork does not surface.
"sora_client_enabled": "upstream-only field, not used on this fork",
// force_email_on_third_party_signup lives on the DTO but is not injected via SSR.
"force_email_on_third_party_signup": "auth-source default, not a feature flag",
}
var missing []string
for key := range dtoKeys {
if _, ok := injection[key]; ok {
continue
}
if _, allowed := dtoOnlyFields[key]; allowed {
continue
}
missing = append(missing, key)
}
if len(missing) > 0 {
t.Fatalf("service.PublicSettingsInjectionPayload is missing JSON fields present on dto.PublicSettings: %s\n"+
"add the field to PublicSettingsInjectionPayload (and GetPublicSettingsForInjection), or "+
"document the exclusion in dtoOnlyFields with a reason.", strings.Join(missing, ", "))
}
}
func jsonTags(t reflect.Type) map[string]struct{} {
out := make(map[string]struct{})
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
tag := f.Tag.Get("json")
if tag == "" || tag == "-" {
continue
}
name := strings.SplitN(tag, ",", 2)[0]
if name == "" {
continue
}
out[name] = struct{}{}
}
return out
}
......@@ -184,6 +184,13 @@ type SystemSettings struct {
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
// Channel Monitor feature switch
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
// Available Channels feature switch (user-facing aggregate view)
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
}
type DefaultSubscriptionSetting struct {
......@@ -231,6 +238,11 @@ type PublicSettings struct {
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
......
......@@ -304,6 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", apiKey.GroupID),
zap.String("platform", platform),
zap.Error(err),
)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
......@@ -347,6 +353,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
zap.String("platform", platform),
)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
......@@ -528,6 +539,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", currentAPIKey.GroupID),
zap.String("platform", platform),
zap.Bool("fallback_used", fallbackUsed),
zap.Error(err),
)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
......@@ -571,6 +589,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
zap.String("platform", platform),
)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
......
......@@ -6,50 +6,54 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
Backup *admin.BackupHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler
Ops *admin.OpsHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
Payment *admin.PaymentHandler
Dashboard *admin.DashboardHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
Backup *admin.BackupHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler
Ops *admin.OpsHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
APIKey *admin.AdminAPIKeyHandler
ScheduledTest *admin.ScheduledTestHandler
Channel *admin.ChannelHandler
ChannelMonitor *admin.ChannelMonitorHandler
ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
Payment *admin.PaymentHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
Payment *PaymentHandler
PaymentWebhook *PaymentWebhookHandler
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Announcement *AnnouncementHandler
ChannelMonitor *ChannelMonitorUserHandler
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
Payment *PaymentHandler
PaymentWebhook *PaymentWebhookHandler
AvailableChannel *AvailableChannelHandler
}
// BuildInfo contains build-time information
......
......@@ -2,6 +2,7 @@ package handler
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
......@@ -114,6 +115,20 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
}
if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
// Unknown order: ack with 2xx so the provider stops retrying. This
// guards against foreign environments whose webhook endpoints are
// (mis)configured to point at us — without a 2xx, the provider will
// retry for days and spam our error logs. We still emit a WARN so the
// event is discoverable in logs.
if errors.Is(err, service.ErrOrderNotFound) {
slog.Warn("[Payment Webhook] unknown order, acking to stop retries",
"provider", resolvedProviderKey,
"outTradeNo", notification.OrderID,
"tradeNo", notification.TradeNo,
)
writeSuccessResponse(c, resolvedProviderKey)
return
}
slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
......
......@@ -6,11 +6,13 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -91,6 +93,43 @@ func TestWriteSuccessResponse(t *testing.T) {
}
}
// TestUnknownOrderWebhookAcksWithSuccess exercises the response contract that
// handleNotify relies on when HandlePaymentNotification returns ErrOrderNotFound:
// we still need to emit the provider-specific 2xx so the provider stops
// retrying. We can't easily drive handleNotify end-to-end without mocking the
// concrete *service.PaymentService, so this test locks down the two ingredients
// the fix depends on:
// 1. errors.Is recognises the sentinel through fmt.Errorf %w wrapping (which
// is how service layer wraps it with the out_trade_no context).
// 2. writeSuccessResponse produces the provider-specific body for Stripe
// (empty 200) — matching what handleNotify calls on the ack path.
//
// If either contract breaks, the Stripe "unknown order → 500 loop" regresses.
func TestUnknownOrderWebhookAcksWithSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
// 1) Sentinel recognition through wrapping.
wrapped := fmt.Errorf("%w: out_trade_no=sub2_missing_42", service.ErrOrderNotFound)
require.True(t, errors.Is(wrapped, service.ErrOrderNotFound),
"handleNotify uses errors.Is on the wrapped service error; regression here "+
"would mean unknown-order webhooks go back to returning 500 and looping forever")
// A distinct error must NOT match — otherwise a DB failure would be silently
// swallowed as an ack.
other := errors.New("lookup order failed: connection refused")
require.False(t, errors.Is(other, service.ErrOrderNotFound))
// 2) Provider-specific success body is what handleNotify emits on the
// ack path. Asserted again here because this is the shape Stripe expects
// to consider the webhook acknowledged.
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
writeSuccessResponse(c, payment.TypeStripe)
require.Equal(t, http.StatusOK, w.Code,
"Stripe requires 2xx to stop retrying; anything else restarts the retry loop")
require.Empty(t, w.Body.String(), "Stripe expects an empty body on the ack path")
}
func TestWebhookConstants(t *testing.T) {
t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) {
assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment