Commit ebac0dc6 authored by erio's avatar erio
Browse files

feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制

- 缓存重构为 O(1) 哈希结构 (pricingByGroupModel, mappingByGroupModel)
- 渠道模型映射接入网关流程 (Forward 前应用, a→b→c 映射链)
- 新增 billing_model_source 配置 (请求模型/最终模型计费)
- usage_logs 新增 channel_id, model_mapping_chain, billing_tier 字段
- 每种计费模式统一支持默认价格 + 区间定价
- 渠道模型限制开关 (restrict_models)
- 分组按平台分类展示 + 彩色图标
- 必填字段红色星号 + 模型映射 UI
- 去除模型通配符支持
parent 29d58f24
...@@ -53,6 +53,9 @@ func (UsageLog) Fields() []ent.Field { ...@@ -53,6 +53,9 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(100). MaxLen(100).
Optional(). Optional().
Nillable(), Nillable(),
field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
field.Int64("group_id"). field.Int64("group_id").
Optional(). Optional().
Nillable(), Nillable(),
......
...@@ -24,31 +24,36 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { ...@@ -24,31 +24,36 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler {
// --- Request / Response types --- // --- Request / Response types ---
type createChannelRequest struct { type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"` Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"` Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"` ModelMapping map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels bool `json:"restrict_models"`
} }
type updateChannelRequest struct { type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"` Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"` Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"` ModelMapping map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels *bool `json:"restrict_models"`
} }
type channelModelPricingRequest struct { type channelModelPricingRequest struct {
Models []string `json:"models" binding:"required,min=1,max=100"` Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"` PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"`
} }
type pricingIntervalRequest struct { type pricingIntervalRequest struct {
...@@ -64,27 +69,30 @@ type pricingIntervalRequest struct { ...@@ -64,27 +69,30 @@ type pricingIntervalRequest struct {
} }
type channelResponse struct { type channelResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Status string `json:"status"` Status string `json:"status"`
GroupIDs []int64 `json:"group_ids"` BillingModelSource string `json:"billing_model_source"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"` RestrictModels bool `json:"restrict_models"`
ModelMapping map[string]string `json:"model_mapping"` GroupIDs []int64 `json:"group_ids"`
CreatedAt string `json:"created_at"` ModelPricing []channelModelPricingResponse `json:"model_pricing"`
UpdatedAt string `json:"updated_at"` ModelMapping map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
} }
type channelModelPricingResponse struct { type channelModelPricingResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Models []string `json:"models"` Models []string `json:"models"`
BillingMode string `json:"billing_mode"` BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"` InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"` OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"` CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"` CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"` ImageOutputPrice *float64 `json:"image_output_price"`
Intervals []pricingIntervalResponse `json:"intervals"` PerRequestPrice *float64 `json:"per_request_price"`
Intervals []pricingIntervalResponse `json:"intervals"`
} }
type pricingIntervalResponse struct { type pricingIntervalResponse struct {
...@@ -109,11 +117,16 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -109,11 +117,16 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Name: ch.Name, Name: ch.Name,
Description: ch.Description, Description: ch.Description,
Status: ch.Status, Status: ch.Status,
RestrictModels: ch.RestrictModels,
GroupIDs: ch.GroupIDs, GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping, ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
} }
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = "requested"
}
if resp.GroupIDs == nil { if resp.GroupIDs == nil {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
} }
...@@ -155,6 +168,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -155,6 +168,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
CacheWritePrice: p.CacheWritePrice, CacheWritePrice: p.CacheWritePrice,
CacheReadPrice: p.CacheReadPrice, CacheReadPrice: p.CacheReadPrice,
ImageOutputPrice: p.ImageOutputPrice, ImageOutputPrice: p.ImageOutputPrice,
PerRequestPrice: p.PerRequestPrice,
Intervals: intervals, Intervals: intervals,
}) })
} }
...@@ -190,6 +204,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -190,6 +204,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
CacheWritePrice: r.CacheWritePrice, CacheWritePrice: r.CacheWritePrice,
CacheReadPrice: r.CacheReadPrice, CacheReadPrice: r.CacheReadPrice,
ImageOutputPrice: r.ImageOutputPrice, ImageOutputPrice: r.ImageOutputPrice,
PerRequestPrice: r.PerRequestPrice,
Intervals: intervals, Intervals: intervals,
}) })
} }
...@@ -249,11 +264,13 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -249,11 +264,13 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelPricing: pricingRequestToService(req.ModelPricing), ModelPricing: pricingRequestToService(req.ModelPricing),
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
...@@ -279,11 +296,13 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -279,11 +296,13 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
input := &service.UpdateChannelInput{ input := &service.UpdateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Status: req.Status, Status: req.Status,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
......
...@@ -604,6 +604,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { ...@@ -604,6 +604,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
return &AdminUsageLog{ return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l), UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel, UpstreamModel: l.UpstreamModel,
ChannelID: l.ChannelID,
ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),
......
...@@ -406,6 +406,13 @@ type AdminUsageLog struct { ...@@ -406,6 +406,13 @@ type AdminUsageLog struct {
// Omitted when no mapping was applied (requested model was used as-is). // Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"` UpstreamModel *string `json:"upstream_model,omitempty"`
// ChannelID 渠道 ID
ChannelID *int64 `json:"channel_id,omitempty"`
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string `json:"billing_tier,omitempty"`
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
......
...@@ -158,6 +158,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -158,6 +158,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 解析渠道级模型映射
var channelMapping service.ChannelMappingResult
if apiKey.GroupID != nil {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
if apiKey.GroupID != nil {
checkModel := reqModel
if channelMapping.Mapped {
checkModel = channelMapping.MappedModel
}
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) {
h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel)
return
}
}
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
...@@ -478,6 +496,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -478,6 +496,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: func() string {
if !channelMapping.Mapped {
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
return reqModel + "→" + result.UpstreamModel
}
return ""
}
if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel {
return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel
}
return reqModel + "→" + channelMapping.MappedModel
}(),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),
...@@ -660,6 +693,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -660,6 +693,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
parsedReq.OnUpstreamAccepted = queueRelease parsedReq.OnUpstreamAccepted = queueRelease
// ===== 用户消息串行队列 END ===== // ===== 用户消息串行队列 END =====
// 应用渠道模型映射到请求
if channelMapping.Mapped {
parsedReq.Model = channelMapping.MappedModel
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
...@@ -810,6 +850,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -810,6 +850,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: func() string {
if !channelMapping.Mapped {
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
return reqModel + "→" + result.UpstreamModel
}
return ""
}
if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel {
return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel
}
return reqModel + "→" + channelMapping.MappedModel
}(),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),
......
...@@ -42,9 +42,9 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel ...@@ -42,9 +42,9 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
return err return err
} }
err = tx.QueryRowContext(ctx, err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping) VALUES ($1, $2, $3, $4) `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`, RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil { if err != nil {
if isUniqueViolation(err) { if isUniqueViolation(err) {
...@@ -75,9 +75,9 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha ...@@ -75,9 +75,9 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
ch := &service.Channel{} ch := &service.Channel{}
var modelMappingJSON []byte var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx, err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, created_at, updated_at `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
FROM channels WHERE id = $1`, id, FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt) ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound return nil, service.ErrChannelNotFound
} }
...@@ -108,9 +108,9 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel ...@@ -108,9 +108,9 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
return err return err
} }
result, err := tx.ExecContext(ctx, result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, updated_at = NOW() `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $5`, WHERE id = $7`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.ID, channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
) )
if err != nil { if err != nil {
if isUniqueViolation(err) { if isUniqueViolation(err) {
...@@ -187,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati ...@@ -187,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表 // 查询 channel 列表
dataQuery := fmt.Sprintf( dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.created_at, c.updated_at `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`, FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1, whereClause, argIdx, argIdx+1,
) )
...@@ -204,7 +204,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati ...@@ -204,7 +204,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
for rows.Next() { for rows.Next() {
var ch service.Channel var ch service.Channel
var modelMappingJSON []byte var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err) return nil, nil, fmt.Errorf("scan channel: %w", err)
} }
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
...@@ -248,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati ...@@ -248,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels ORDER BY id`, `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("query all channels: %w", err) return nil, fmt.Errorf("query all channels: %w", err)
...@@ -260,7 +260,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err ...@@ -260,7 +260,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
for rows.Next() { for rows.Next() {
var ch service.Channel var ch service.Channel
var modelMappingJSON []byte var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err) return nil, fmt.Errorf("scan channel: %w", err)
} }
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
) )
if err != nil { if err != nil {
...@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser ...@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser
} }
result, err := r.db.ExecContext(ctx, result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing `UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW() SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, updated_at = NOW()
WHERE id = $8`, WHERE id = $9`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.ID, pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update model pricing: %w", err) return fmt.Errorf("update model pricing: %w", err)
...@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i ...@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) // batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs), pq.Array(channelIDs),
) )
...@@ -171,7 +171,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6 ...@@ -171,7 +171,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6
if err := rows.Scan( if err := rows.Scan(
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode, &p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt, &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil { ); err != nil {
return nil, nil, fmt.Errorf("scan model pricing: %w", err) return nil, nil, fmt.Errorf("scan model pricing: %w", err)
} }
...@@ -224,11 +224,11 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C ...@@ -224,11 +224,11 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
err = exec.QueryRowContext(ctx, err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price) `INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`, VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`,
pricing.ChannelID, modelsJSON, billingMode, pricing.ChannelID, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil { if err != nil {
return fmt.Errorf("insert model pricing: %w", err) return fmt.Errorf("insert model pricing: %w", err)
......
...@@ -28,7 +28,7 @@ import ( ...@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
...@@ -77,6 +77,9 @@ var usageLogInsertArgTypes = [...]string{ ...@@ -77,6 +77,9 @@ var usageLogInsertArgTypes = [...]string{
"text", // inbound_endpoint "text", // inbound_endpoint
"text", // upstream_endpoint "text", // upstream_endpoint
"boolean", // cache_ttl_overridden "boolean", // cache_ttl_overridden
"bigint", // channel_id
"text", // model_mapping_chain
"text", // billing_tier
"timestamptz", // created_at "timestamptz", // created_at
} }
...@@ -350,6 +353,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -350,6 +353,9 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
...@@ -357,7 +363,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ...@@ -357,7 +363,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15,
$16, $17, $18, $19, $20, $21, $16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
...@@ -782,10 +788,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -782,10 +788,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(keys)*39) args := make([]any, 0, len(keys)*44)
argPos := 1 argPos := 1
for idx, key := range keys { for idx, key := range keys {
if idx > 0 { if idx > 0 {
...@@ -853,6 +862,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -853,6 +862,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) )
SELECT SELECT
...@@ -895,6 +907,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ...@@ -895,6 +907,9 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
...@@ -977,10 +992,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -977,10 +992,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*40) args := make([]any, 0, len(preparedList)*43)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
...@@ -1045,6 +1063,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1045,6 +1063,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) )
SELECT SELECT
...@@ -1087,6 +1108,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ...@@ -1087,6 +1108,9 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
...@@ -1137,6 +1161,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1137,6 +1161,9 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
inbound_endpoint, inbound_endpoint,
upstream_endpoint, upstream_endpoint,
cache_ttl_overridden, cache_ttl_overridden,
channel_id,
model_mapping_chain,
billing_tier,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
...@@ -1144,7 +1171,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ...@@ -1144,7 +1171,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $14, $15,
$16, $17, $18, $19, $20, $21, $16, $17, $18, $19, $20, $21,
$22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
...@@ -1176,6 +1203,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1176,6 +1203,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort) reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint) inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint)
channelID := nullInt64(log.ChannelID)
modelMappingChain := nullString(log.ModelMappingChain)
billingTier := nullString(log.BillingTier)
requestedModel := strings.TrimSpace(log.RequestedModel) requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" { if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model) requestedModel = strings.TrimSpace(log.Model)
...@@ -1232,6 +1262,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ...@@ -1232,6 +1262,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
inboundEndpoint, inboundEndpoint,
upstreamEndpoint, upstreamEndpoint,
log.CacheTTLOverridden, log.CacheTTLOverridden,
channelID,
modelMappingChain,
billingTier,
createdAt, createdAt,
}, },
} }
...@@ -3959,6 +3992,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -3959,6 +3992,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
inboundEndpoint sql.NullString inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString upstreamEndpoint sql.NullString
cacheTTLOverridden bool cacheTTLOverridden bool
channelID sql.NullInt64
modelMappingChain sql.NullString
billingTier sql.NullString
createdAt time.Time createdAt time.Time
) )
...@@ -4003,6 +4039,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4003,6 +4039,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&inboundEndpoint, &inboundEndpoint,
&upstreamEndpoint, &upstreamEndpoint,
&cacheTTLOverridden, &cacheTTLOverridden,
&channelID,
&modelMappingChain,
&billingTier,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
...@@ -4087,6 +4126,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ...@@ -4087,6 +4126,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid { if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String log.UpstreamModel = &upstreamModel.String
} }
if channelID.Valid {
value := channelID.Int64
log.ChannelID = &value
}
if modelMappingChain.Valid {
log.ModelMappingChain = &modelMappingChain.String
}
if billingTier.Valid {
log.BillingTier = &billingTier.String
}
return log, nil return log, nil
} }
......
...@@ -23,14 +23,21 @@ func (m BillingMode) IsValid() bool { ...@@ -23,14 +23,21 @@ func (m BillingMode) IsValid() bool {
return false return false
} }
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
)
// Channel 渠道实体 // Channel 渠道实体
type Channel struct { type Channel struct {
ID int64 ID int64
Name string Name string
Description string Description string
Status string Status string
CreatedAt time.Time BillingModelSource string // "requested" or "upstream"
UpdatedAt time.Time RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表 // 关联的分组 ID 列表
GroupIDs []int64 GroupIDs []int64
...@@ -44,13 +51,14 @@ type Channel struct { ...@@ -44,13 +51,14 @@ type Channel struct {
type ChannelModelPricing struct { type ChannelModelPricing struct {
ID int64 ID int64
ChannelID int64 ChannelID int64
Models []string // 绑定的模型列表 Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式 BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格(USD) OutputPrice *float64 // 每 token 输出价格(USD)
CacheWritePrice *float64 // 缓存写入价格 CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格 CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容) ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格(USD)
Intervals []PricingInterval // 区间定价列表 Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
...@@ -106,12 +114,10 @@ func (c *Channel) IsActive() bool { ...@@ -106,12 +114,10 @@ func (c *Channel) IsActive() bool {
} }
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 // GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。 // 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
// 返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model) modelLower := strings.ToLower(model)
// 第一轮:精确匹配
for i := range c.ModelPricing { for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models { for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower { if strings.ToLower(m) == modelLower {
...@@ -121,20 +127,6 @@ func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { ...@@ -121,20 +127,6 @@ func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
} }
} }
// 第二轮:通配符匹配(仅支持末尾 *)
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
mLower := strings.ToLower(m)
if strings.HasSuffix(mLower, "*") {
prefix := strings.TrimSuffix(mLower, "*")
if strings.HasPrefix(modelLower, prefix) {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
}
return nil return nil
} }
......
...@@ -47,13 +47,30 @@ type ChannelRepository interface { ...@@ -47,13 +47,30 @@ type ChannelRepository interface {
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
} }
// channelCache 渠道缓存快照 // channelModelKey 渠道缓存复合键
type channelModelKey struct {
groupID int64
model string // lowercase
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct { type channelCache struct {
// byID: channelID -> *Channel(含 ModelPricing) // 热路径查找
byID map[int64]*Channel pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价
// byGroupID: groupID -> channelID mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标
byGroupID map[int64]int64 channelByGroupID map[int64]*Channel // groupID → 渠道
loadedAt time.Time
// 冷路径(CRUD 操作)
byID map[int64]*Channel
loadedAt time.Time
}
// ChannelMappingResult 渠道映射查找结果
type ChannelMappingResult struct {
MappedModel string // 映射后的模型名(无映射时等于原始模型名)
ChannelID int64 // 渠道 ID(0 = 无渠道关联)
Mapped bool // 是否发生了映射
BillingModelSource string // 计费模型来源("requested" / "upstream")
} }
const ( const (
...@@ -115,25 +132,46 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -115,25 +132,46 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err) slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{ errorCache := &channelCache{
byID: make(map[int64]*Channel), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
byGroupID: make(map[int64]int64), mappingByGroupModel: make(map[channelModelKey]string),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
} }
s.cache.Store(errorCache) s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
} }
cache := &channelCache{ cache := &channelCache{
byID: make(map[int64]*Channel, len(channels)), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
byGroupID: make(map[int64]int64), mappingByGroupModel: make(map[channelModelKey]string),
loadedAt: time.Now(), channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
} }
for i := range channels { for i := range channels {
ch := &channels[i] ch := &channels[i]
cache.byID[ch.ID] = ch cache.byID[ch.ID] = ch
// 展开到分组维度
for _, gid := range ch.GroupIDs { for _, gid := range ch.GroupIDs {
cache.byGroupID[gid] = ch.ID cache.channelByGroupID[gid] = ch
// 展开模型定价到 (groupID, model) → *ChannelModelPricing
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
// 展开模型映射到 (groupID, model) → target
for src, dst := range ch.ModelMapping {
key := channelModelKey{groupID: gid, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
} }
} }
...@@ -147,42 +185,94 @@ func (s *ChannelService) invalidateCache() { ...@@ -147,42 +185,94 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache") s.cacheSF.Forget("channel_cache")
} }
// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取) // GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
// 返回深拷贝,不污染缓存。
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx) cache, err := s.loadCache(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
channelID, ok := cache.byGroupID[groupID] ch, ok := cache.channelByGroupID[groupID]
if !ok { if !ok || !ch.IsActive() {
return nil, nil return nil, nil
} }
ch, ok := cache.byID[channelID] return ch.Clone(), nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
cache, err := s.loadCache(ctx)
if err != nil {
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
return nil
}
// 检查渠道是否启用
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil
}
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key]
if !ok { if !ok {
return nil, nil return nil
} }
if !ch.IsActive() { cp := pricing.Clone()
return nil, nil return &cp
}
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
cache, err := s.loadCache(ctx)
if err != nil {
return ChannelMappingResult{MappedModel: model}
} }
return ch.Clone(), nil ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return ChannelMappingResult{MappedModel: model}
}
result := ChannelMappingResult{
MappedModel: model,
ChannelID: ch.ID,
BillingModelSource: ch.BillingModelSource,
}
if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceRequested
}
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped
result.Mapped = true
}
return result
} }
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径) // IsModelRestricted 检查模型是否被渠道限制。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { // 返回 true 表示模型被限制(不在允许列表中)。
ch, err := s.GetChannelForGroup(ctx, groupID) // 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
cache, err := s.loadCache(ctx)
if err != nil { if err != nil {
slog.Warn("failed to get channel for group", "group_id", groupID, "error", err) return false // 缓存加载失败时不限制
return nil
} }
if ch == nil {
return nil ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() || !ch.RestrictModels {
return false
} }
return ch.GetModelPricing(model)
// 检查模型是否在定价列表中
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key]
return !exists
} }
// --- CRUD --- // --- CRUD ---
...@@ -209,12 +299,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -209,12 +299,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
} }
channel := &Channel{ channel := &Channel{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Status: StatusActive, Status: StatusActive,
GroupIDs: input.GroupIDs, BillingModelSource: input.BillingModelSource,
ModelPricing: input.ModelPricing, RestrictModels: input.RestrictModels,
ModelMapping: input.ModelMapping, GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested
} }
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
...@@ -260,6 +355,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -260,6 +355,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.Status = input.Status channel.Status = input.Status
} }
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突 // 检查分组冲突
if input.GroupIDs != nil { if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
...@@ -280,6 +379,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -280,6 +379,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.ModelMapping = input.ModelMapping channel.ModelMapping = input.ModelMapping
} }
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
return nil, err return nil, err
} }
...@@ -351,19 +454,23 @@ func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { ...@@ -351,19 +454,23 @@ func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
// CreateChannelInput 创建渠道输入 // CreateChannelInput 创建渠道输入
type CreateChannelInput struct { type CreateChannelInput struct {
Name string Name string
Description string Description string
GroupIDs []int64 GroupIDs []int64
ModelPricing []ChannelModelPricing ModelPricing []ChannelModelPricing
ModelMapping map[string]string ModelMapping map[string]string
BillingModelSource string
RestrictModels bool
} }
// UpdateChannelInput 更新渠道输入 // UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct { type UpdateChannelInput struct {
Name string Name string
Description *string Description *string
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string ModelMapping map[string]string
BillingModelSource string
RestrictModels *bool
} }
...@@ -15,7 +15,6 @@ func TestGetModelPricing(t *testing.T) { ...@@ -15,7 +15,6 @@ func TestGetModelPricing(t *testing.T) {
ch := &Channel{ ch := &Channel{
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)}, {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
}, },
} }
...@@ -28,9 +27,8 @@ func TestGetModelPricing(t *testing.T) { ...@@ -28,9 +27,8 @@ func TestGetModelPricing(t *testing.T) {
}{ }{
{"exact match", "claude-sonnet-4", 1, false}, {"exact match", "claude-sonnet-4", 1, false},
{"case insensitive", "Claude-Sonnet-4", 1, false}, {"case insensitive", "Claude-Sonnet-4", 1, false},
{"wildcard match", "claude-opus-4-20250514", 2, false},
{"exact takes priority over wildcard", "claude-sonnet-4", 1, false},
{"not found", "gemini-3.1-pro", 0, true}, {"not found", "gemini-3.1-pro", 0, true},
{"wildcard pattern not matched", "claude-opus-4-20250514", 0, true},
{"per_request model", "gpt-5.1", 3, false}, {"per_request model", "gpt-5.1", 3, false},
} }
......
...@@ -7413,6 +7413,12 @@ type RecordUsageInput struct { ...@@ -7413,6 +7413,12 @@ type RecordUsageInput struct {
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
} }
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
...@@ -7732,7 +7738,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7732,7 +7738,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" { if result.MediaType == "image" || result.MediaType == "video" {
...@@ -7815,7 +7831,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7815,7 +7831,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
...@@ -7842,6 +7858,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7842,6 +7858,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
ImageSize: imageSize, ImageSize: imageSize,
MediaType: mediaType, MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
...@@ -7909,6 +7927,12 @@ type RecordUsageLongContextInput struct { ...@@ -7909,6 +7927,12 @@ type RecordUsageLongContextInput struct {
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
// 渠道映射信息(由 handler 在 Forward 前解析)
ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
} }
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...@@ -7946,7 +7970,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -7946,7 +7970,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} }
var cost *CostBreakdown var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.ImageCount > 0 { if result.ImageCount > 0 {
...@@ -8008,7 +8042,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8008,7 +8042,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
...@@ -8034,6 +8068,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8034,6 +8068,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: imageSize,
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
...@@ -8085,6 +8121,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8085,6 +8121,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
return nil return nil
} }
// ResolveChannelMapping 委托渠道服务解析模型映射
func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}
}
return s.channelService.ResolveChannelMapping(ctx, groupID, model)
}
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return s.replaceModelInBody(body, newModel)
}
// IsModelRestricted 检查模型是否被渠道限制
func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
if s.channelService == nil {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......
...@@ -104,6 +104,12 @@ type UsageLog struct { ...@@ -104,6 +104,12 @@ type UsageLog struct {
// UpstreamModel is the actual model sent to the upstream provider after mapping. // UpstreamModel is the actual model sent to the upstream provider after mapping.
// Nil means no mapping was applied (requested model was used as-is). // Nil means no mapping was applied (requested model was used as-is).
UpstreamModel *string UpstreamModel *string
// ChannelID 渠道 ID
ChannelID *int64
// ModelMappingChain 模型映射链,如 "a→b→c"
ModelMappingChain *string
// BillingTier 计费层级标签(per_request/image 模式)
BillingTier *string
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string ServiceTier *string
// ReasoningEffort is the request's reasoning effort level. // ReasoningEffort is the request's reasoning effort level.
......
...@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string { ...@@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string {
} }
return strings.TrimSpace(upstreamModel) return strings.TrimSpace(upstreamModel)
} }
func optionalInt64Ptr(v int64) *int64 {
if v == 0 {
return nil
}
return &v
}
-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
-- Add channel tracking fields to usage_logs
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
-- Add model restriction switch to channels
ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
...@@ -29,6 +29,7 @@ export interface ChannelModelPricing { ...@@ -29,6 +29,7 @@ export interface ChannelModelPricing {
cache_write_price: number | null cache_write_price: number | null
cache_read_price: number | null cache_read_price: number | null
image_output_price: number | null image_output_price: number | null
per_request_price: number | null
intervals: PricingInterval[] intervals: PricingInterval[]
} }
...@@ -37,8 +38,11 @@ export interface Channel { ...@@ -37,8 +38,11 @@ export interface Channel {
name: string name: string
description: string description: string
status: string status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
group_ids: number[] group_ids: number[]
model_pricing: ChannelModelPricing[] model_pricing: ChannelModelPricing[]
model_mapping: Record<string, string>
created_at: string created_at: string
updated_at: string updated_at: string
} }
...@@ -48,6 +52,9 @@ export interface CreateChannelRequest { ...@@ -48,6 +52,9 @@ export interface CreateChannelRequest {
description?: string description?: string
group_ids?: number[] group_ids?: number[]
model_pricing?: ChannelModelPricing[] model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, string>
billing_model_source?: string
restrict_models?: boolean
} }
export interface UpdateChannelRequest { export interface UpdateChannelRequest {
...@@ -56,6 +63,9 @@ export interface UpdateChannelRequest { ...@@ -56,6 +63,9 @@ export interface UpdateChannelRequest {
status?: string status?: string
group_ids?: number[] group_ids?: number[]
model_pricing?: ChannelModelPricing[] model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, string>
billing_model_source?: string
restrict_models?: boolean
} }
interface PaginatedResponse<T> { interface PaginatedResponse<T> {
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
/> />
</div> </div>
<p class="mt-1 text-xs text-gray-400"> <p class="mt-1 text-xs text-gray-400">
{{ t('admin.channels.form.modelInputHint', 'Press Enter to add. Supports wildcard *.') }} {{ t('admin.channels.form.modelInputHint', 'Press Enter to add, supports paste for batch import.') }}
</p> </p>
</div> </div>
</template> </template>
......
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
<div class="mt-3 flex items-start gap-2"> <div class="mt-3 flex items-start gap-2">
<div class="flex-1"> <div class="flex-1">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400"> <label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.models', '模型列表') }} {{ t('admin.channels.form.models', '模型列表') }} <span class="text-red-500">*</span>
</label> </label>
<ModelTagInput <ModelTagInput
:models="entry.models" :models="entry.models"
...@@ -153,6 +153,17 @@ ...@@ -153,6 +153,17 @@
<!-- Per-request mode --> <!-- Per-request mode -->
<div v-else-if="entry.billing_mode === 'per_request'"> <div v-else-if="entry.billing_mode === 'per_request'">
<!-- Default per-request price -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultPerRequestPrice', '默认单次价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.per_request_price" @input="emitField('per_request_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Tiers -->
<div class="mt-3 flex items-center justify-between"> <div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400"> <label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.requestTiers', '按次计费层级') }} {{ t('admin.channels.form.requestTiers', '按次计费层级') }}
...@@ -176,8 +187,19 @@ ...@@ -176,8 +187,19 @@
</div> </div>
</div> </div>
<!-- Image mode (legacy per-request) --> <!-- Image mode -->
<div v-else-if="entry.billing_mode === 'image'"> <div v-else-if="entry.billing_mode === 'image'">
<!-- Default image price -->
<label class="mt-3 block text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.defaultImagePrice', '默认图片价格(未命中层级时使用)') }}
<span class="ml-1 font-normal text-gray-400">$</span>
</label>
<div class="mt-1 w-48">
<input :value="entry.image_output_price" @input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
<!-- Image tiers -->
<div class="mt-3 flex items-center justify-between"> <div class="mt-3 flex items-center justify-between">
<label class="text-xs font-medium text-gray-500 dark:text-gray-400"> <label class="text-xs font-medium text-gray-500 dark:text-gray-400">
{{ t('admin.channels.form.imageTiers', '图片计费层级(按次)') }} {{ t('admin.channels.form.imageTiers', '图片计费层级(按次)') }}
...@@ -196,15 +218,6 @@ ...@@ -196,15 +218,6 @@
@remove="removeInterval(idx)" @remove="removeInterval(idx)"
/> />
</div> </div>
<div v-else>
<div class="mt-2 grid grid-cols-2 gap-2 sm:grid-cols-4">
<div>
<label class="text-xs text-gray-400">{{ t('admin.channels.form.imageOutputPrice', '图片输出价格') }}</label>
<input :value="entry.image_output_price" @input="emitField('image_output_price', ($event.target as HTMLInputElement).value)"
type="number" step="any" min="0" class="input mt-0.5 text-sm" :placeholder="t('admin.channels.form.pricePlaceholder', '默认')" />
</div>
</div>
</div>
</div> </div>
</div> </div>
</div> </div>
......
...@@ -20,6 +20,7 @@ export interface PricingFormEntry { ...@@ -20,6 +20,7 @@ export interface PricingFormEntry {
cache_write_price: number | string | null cache_write_price: number | string | null
cache_read_price: number | string | null cache_read_price: number | string | null
image_output_price: number | string | null image_output_price: number | string | null
per_request_price: number | string | null
intervals: IntervalFormEntry[] intervals: IntervalFormEntry[]
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment