"backend/vscode:/vscode.git/clone" did not exist on "bfcc562c35043f48aef0d83f4e8734ec231f1a59"
Commit 0b1ce6be authored by erio's avatar erio
Browse files

feat(channel): 缓存扁平化 + 网关映射集成 + 计费模式统一 + 模型限制

- 缓存按 (groupID, platform, model) 三维 key 扁平化,避免跨平台同名模型冲突
- buildCache 批量查询 group platform,按平台过滤展开定价和映射
- model_mapping 改为嵌套格式 {platform: {src: dst}}
- channel_model_pricing 新增 platform 列
- 前端按平台维度重构:每个平台独立配置分组/映射/定价
- 迁移 086: platform 列 + model_mapping 嵌套格式迁移
parent 28a6adaa
...@@ -28,7 +28,7 @@ type createChannelRequest struct { ...@@ -28,7 +28,7 @@ type createChannelRequest struct {
Description string `json:"description"` Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
} }
...@@ -39,12 +39,13 @@ type updateChannelRequest struct { ...@@ -39,12 +39,13 @@ type updateChannelRequest struct {
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
} }
type channelModelPricingRequest struct { type channelModelPricingRequest struct {
Platform string `json:"platform" binding:"omitempty,max=50"`
Models []string `json:"models" binding:"required,min=1,max=100"` Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
...@@ -77,13 +78,14 @@ type channelResponse struct { ...@@ -77,13 +78,14 @@ type channelResponse struct {
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
} }
type channelModelPricingResponse struct { type channelModelPricingResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Platform string `json:"platform"`
Models []string `json:"models"` Models []string `json:"models"`
BillingMode string `json:"billing_mode"` BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"` InputPrice *float64 `json:"input_price"`
...@@ -131,7 +133,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -131,7 +133,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
} }
if resp.ModelMapping == nil { if resp.ModelMapping == nil {
resp.ModelMapping = map[string]string{} resp.ModelMapping = map[string]map[string]string{}
} }
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing)) resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
...@@ -144,6 +146,10 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -144,6 +146,10 @@ func channelToResponse(ch *service.Channel) *channelResponse {
if billingMode == "" { if billingMode == "" {
billingMode = "token" billingMode = "token"
} }
platform := p.Platform
if platform == "" {
platform = "anthropic"
}
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals { for _, iv := range p.Intervals {
intervals = append(intervals, pricingIntervalResponse{ intervals = append(intervals, pricingIntervalResponse{
...@@ -161,6 +167,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -161,6 +167,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
resp.ModelPricing = append(resp.ModelPricing, channelModelPricingResponse{ resp.ModelPricing = append(resp.ModelPricing, channelModelPricingResponse{
ID: p.ID, ID: p.ID,
Platform: platform,
Models: models, Models: models,
BillingMode: billingMode, BillingMode: billingMode,
InputPrice: p.InputPrice, InputPrice: p.InputPrice,
...@@ -182,6 +189,10 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -182,6 +189,10 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
if billingMode == "" { if billingMode == "" {
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
platform := r.Platform
if platform == "" {
platform = "anthropic"
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{ intervals = append(intervals, service.PricingInterval{
...@@ -197,6 +208,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -197,6 +208,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
}) })
} }
result = append(result, service.ChannelModelPricing{ result = append(result, service.ChannelModelPricing{
Platform: platform,
Models: r.Models, Models: r.Models,
BillingMode: billingMode, BillingMode: billingMode,
InputPrice: r.InputPrice, InputPrice: r.InputPrice,
......
...@@ -406,8 +406,9 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe ...@@ -406,8 +406,9 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe
return conflicting, nil return conflicting, nil
} }
// marshalModelMapping 将 model mapping 序列化为 JSON 字节,nil/空 map 返回 '{}' // marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节
func marshalModelMapping(m map[string]string) ([]byte, error) { // 格式:{"platform": {"src": "dst"}, ...}
func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
if len(m) == 0 { if len(m) == 0 {
return []byte("{}"), nil return []byte("{}"), nil
} }
...@@ -418,14 +419,43 @@ func marshalModelMapping(m map[string]string) ([]byte, error) { ...@@ -418,14 +419,43 @@ func marshalModelMapping(m map[string]string) ([]byte, error) {
return data, nil return data, nil
} }
// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping // unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping
func unmarshalModelMapping(data []byte) map[string]string { func unmarshalModelMapping(data []byte) map[string]map[string]string {
if len(data) == 0 { if len(data) == 0 {
return nil return nil
} }
var m map[string]string var m map[string]map[string]string
if err := json.Unmarshal(data, &m); err != nil { if err := json.Unmarshal(data, &m); err != nil {
return nil return nil
} }
return m return m
} }
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
return make(map[int64]string), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, platform FROM groups WHERE id = ANY($1)`,
pq.Array(groupIDs),
)
if err != nil {
return nil, fmt.Errorf("get group platforms: %w", err)
}
defer rows.Close()
result := make(map[int64]string, len(groupIDs))
for rows.Next() {
var id int64
var platform string
if err := rows.Scan(&id, &platform); err != nil {
return nil, fmt.Errorf("scan group platform: %w", err)
}
result[id] = platform
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group platforms: %w", err)
}
return result, nil
}
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
) )
if err != nil { if err != nil {
...@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser ...@@ -56,10 +56,10 @@ func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *ser
} }
result, err := r.db.ExecContext(ctx, result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing `UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, updated_at = NOW() SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
WHERE id = $9`, WHERE id = $10`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ID, pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update model pricing: %w", err) return fmt.Errorf("update model pricing: %w", err)
...@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i ...@@ -90,7 +90,7 @@ func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID i
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) // batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs), pq.Array(channelIDs),
) )
...@@ -169,7 +169,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6 ...@@ -169,7 +169,7 @@ func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int6
var p service.ChannelModelPricing var p service.ChannelModelPricing
var modelsJSON []byte var modelsJSON []byte
if err := rows.Scan( if err := rows.Scan(
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode, &p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil { ); err != nil {
...@@ -223,10 +223,14 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C ...@@ -223,10 +223,14 @@ func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.C
if billingMode == "" { if billingMode == "" {
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
platform := pricing.Platform
if platform == "" {
platform = "anthropic"
}
err = exec.QueryRowContext(ctx, err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price) `INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
pricing.ChannelID, modelsJSON, billingMode, pricing.ChannelID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
......
...@@ -41,16 +41,17 @@ type Channel struct { ...@@ -41,16 +41,17 @@ type Channel struct {
// 关联的分组 ID 列表 // 关联的分组 ID 列表
GroupIDs []int64 GroupIDs []int64
// 模型定价列表 // 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing ModelPricing []ChannelModelPricing
// 渠道级模型映射 // 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping map[string]string ModelMapping map[string]map[string]string
} }
// ChannelModelPricing 渠道模型定价条目 // ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct { type ChannelModelPricing struct {
ID int64 ID int64
ChannelID int64 ChannelID int64
Platform string // 所属平台(anthropic/openai/gemini/...)
Models []string // 绑定的模型列表 Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式 BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
...@@ -82,21 +83,26 @@ type PricingInterval struct { ...@@ -82,21 +83,26 @@ type PricingInterval struct {
} }
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。 // ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
// platform 指定查找哪个平台的映射规则。
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。 // 支持通配符(如 "claude-*" → "claude-sonnet-4")。
// 如果没有匹配的映射规则,返回原始模型名。 // 如果没有匹配的映射规则,返回原始模型名。
func (c *Channel) ResolveMappedModel(requestedModel string) string { func (c *Channel) ResolveMappedModel(platform, requestedModel string) string {
if len(c.ModelMapping) == 0 { if len(c.ModelMapping) == 0 {
return requestedModel return requestedModel
} }
platformMapping, ok := c.ModelMapping[platform]
if !ok || len(platformMapping) == 0 {
return requestedModel
}
lower := strings.ToLower(requestedModel) lower := strings.ToLower(requestedModel)
// 精确匹配优先 // 精确匹配优先
for src, dst := range c.ModelMapping { for src, dst := range platformMapping {
if strings.ToLower(src) == lower { if strings.ToLower(src) == lower {
return dst return dst
} }
} }
// 通配符匹配 // 通配符匹配
for src, dst := range c.ModelMapping { for src, dst := range platformMapping {
srcLower := strings.ToLower(src) srcLower := strings.ToLower(src)
if strings.HasSuffix(srcLower, "*") { if strings.HasSuffix(srcLower, "*") {
prefix := strings.TrimSuffix(srcLower, "*") prefix := strings.TrimSuffix(srcLower, "*")
...@@ -190,9 +196,13 @@ func (c *Channel) Clone() *Channel { ...@@ -190,9 +196,13 @@ func (c *Channel) Clone() *Channel {
} }
} }
if c.ModelMapping != nil { if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]string, len(c.ModelMapping)) cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for k, v := range c.ModelMapping { for platform, mapping := range c.ModelMapping {
cp.ModelMapping[k] = v inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
} }
} }
return &cp return &cp
......
...@@ -39,6 +39,9 @@ type ChannelRepository interface { ...@@ -39,6 +39,9 @@ type ChannelRepository interface {
GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error)
GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error)
// 分组平台查询
GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error)
// 模型定价 // 模型定价
ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error)
CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error
...@@ -47,18 +50,20 @@ type ChannelRepository interface { ...@@ -47,18 +50,20 @@ type ChannelRepository interface {
ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error
} }
// channelModelKey 渠道缓存复合键 // channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突)
type channelModelKey struct { type channelModelKey struct {
groupID int64 groupID int64
platform string // 平台标识
model string // lowercase model string // lowercase
} }
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) // channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct { type channelCache struct {
// 热路径查找 // 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, model) → 定价 pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, model) → 映射目标 mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
channelByGroupID map[int64]*Channel // groupID → 渠道 channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径(CRUD 操作) // 冷路径(CRUD 操作)
byID map[int64]*Channel byID map[int64]*Channel
...@@ -135,6 +140,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -135,6 +140,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string), mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel), byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
} }
...@@ -142,10 +148,25 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -142,10 +148,25 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
} }
// 收集所有 groupID,批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
// 降级:继续构建缓存但无法按平台过滤
}
}
cache := &channelCache{ cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string), mappingByGroupModel: make(map[channelModelKey]string),
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)), byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(), loadedAt: time.Now(),
} }
...@@ -157,23 +178,29 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -157,23 +178,29 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// 展开到分组维度 // 展开到分组维度
for _, gid := range ch.GroupIDs { for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid] // e.g. "anthropic"
// 展开模型定价到 (groupID, model) → *ChannelModelPricing // 展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing
for j := range ch.ModelPricing { for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j] pricing := &ch.ModelPricing[j]
if pricing.Platform != platform {
continue // 跳过非本平台的定价
}
for _, model := range pricing.Models { for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, model: strings.ToLower(model)} key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing cache.pricingByGroupModel[key] = pricing
} }
} }
// 展开模型映射到 (groupID, model) → target // 只展开该平台的模型映射到 (groupID, platform, model) → target
for src, dst := range ch.ModelMapping { if platformMapping, ok := ch.ModelMapping[platform]; ok {
key := channelModelKey{groupID: gid, model: strings.ToLower(src)} for src, dst := range platformMapping {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst cache.mappingByGroupModel[key] = dst
} }
} }
} }
}
s.cache.Store(cache) s.cache.Store(cache)
return cache, nil return cache, nil
...@@ -214,7 +241,8 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int ...@@ -214,7 +241,8 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
return nil return nil
} }
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key] pricing, ok := cache.pricingByGroupModel[key]
if !ok { if !ok {
return nil return nil
...@@ -246,7 +274,8 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -246,7 +274,8 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
result.BillingModelSource = BillingModelSourceRequested result.BillingModelSource = BillingModelSourceRequested
} }
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
if mapped, ok := cache.mappingByGroupModel[key]; ok { if mapped, ok := cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped result.MappedModel = mapped
result.Mapped = true result.Mapped = true
...@@ -270,7 +299,8 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m ...@@ -270,7 +299,8 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
} }
// 检查模型是否在定价列表中 // 检查模型是否在定价列表中
key := channelModelKey{groupID: groupID, model: strings.ToLower(model)} platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key] _, exists := cache.pricingByGroupModel[key]
return !exists return !exists
} }
...@@ -458,7 +488,7 @@ type CreateChannelInput struct { ...@@ -458,7 +488,7 @@ type CreateChannelInput struct {
Description string Description string
GroupIDs []int64 GroupIDs []int64
ModelPricing []ChannelModelPricing ModelPricing []ChannelModelPricing
ModelMapping map[string]string ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string BillingModelSource string
RestrictModels bool RestrictModels bool
} }
...@@ -470,7 +500,7 @@ type UpdateChannelInput struct { ...@@ -470,7 +500,7 @@ type UpdateChannelInput struct {
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string BillingModelSource string
RestrictModels *bool RestrictModels *bool
} }
-- 086_channel_platform_pricing.sql
-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式
-- 1. channel_model_pricing 加 platform 列
ALTER TABLE channel_model_pricing
ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
ON channel_model_pricing (platform);
-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
UPDATE channels
SET model_mapping = jsonb_build_object('anthropic', model_mapping)
WHERE model_mapping IS NOT NULL
AND model_mapping::text NOT IN ('{}', 'null', '')
AND NOT EXISTS (
SELECT 1 FROM jsonb_each(model_mapping) AS kv
WHERE jsonb_typeof(kv.value) = 'object'
LIMIT 1
);
...@@ -22,6 +22,7 @@ export interface PricingInterval { ...@@ -22,6 +22,7 @@ export interface PricingInterval {
export interface ChannelModelPricing { export interface ChannelModelPricing {
id?: number id?: number
platform: string
models: string[] models: string[]
billing_mode: BillingMode billing_mode: BillingMode
input_price: number | null input_price: number | null
...@@ -42,7 +43,7 @@ export interface Channel { ...@@ -42,7 +43,7 @@ export interface Channel {
restrict_models: boolean restrict_models: boolean
group_ids: number[] group_ids: number[]
model_pricing: ChannelModelPricing[] model_pricing: ChannelModelPricing[]
model_mapping: Record<string, string> model_mapping: Record<string, Record<string, string>> // platform → {src→dst}
created_at: string created_at: string
updated_at: string updated_at: string
} }
...@@ -52,7 +53,7 @@ export interface CreateChannelRequest { ...@@ -52,7 +53,7 @@ export interface CreateChannelRequest {
description?: string description?: string
group_ids?: number[] group_ids?: number[]
model_pricing?: ChannelModelPricing[] model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, string> model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string billing_model_source?: string
restrict_models?: boolean restrict_models?: boolean
} }
...@@ -63,7 +64,7 @@ export interface UpdateChannelRequest { ...@@ -63,7 +64,7 @@ export interface UpdateChannelRequest {
status?: string status?: string
group_ids?: number[] group_ids?: number[]
model_pricing?: ChannelModelPricing[] model_pricing?: ChannelModelPricing[]
model_mapping?: Record<string, string> model_mapping?: Record<string, Record<string, string>>
billing_model_source?: string billing_model_source?: string
restrict_models?: boolean restrict_models?: boolean
} }
......
...@@ -1806,7 +1806,13 @@ export default { ...@@ -1806,7 +1806,13 @@ export default {
restrictModels: 'Restrict Models', restrictModels: 'Restrict Models',
restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.', restrictModelsHint: 'When enabled, only models in the pricing list are allowed. Others will be rejected.',
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)', defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
defaultImagePrice: 'Default image price (fallback when no tier matches)' defaultImagePrice: 'Default image price (fallback when no tier matches)',
platformConfig: 'Platform Configuration',
addPlatform: 'Add Platform',
noPlatforms: 'Click "Add Platform" to start configuring the channel',
mappingCount: 'mappings',
pricingEntry: 'Pricing Entry',
noModels: 'No models added'
} }
}, },
......
...@@ -1886,7 +1886,13 @@ export default { ...@@ -1886,7 +1886,13 @@ export default {
restrictModels: '限制模型', restrictModels: '限制模型',
restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。', restrictModelsHint: '开启后,仅允许模型定价列表中的模型。不在列表中的模型请求将被拒绝。',
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)', defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
defaultImagePrice: '默认图片价格(未命中层级时使用)' defaultImagePrice: '默认图片价格(未命中层级时使用)',
platformConfig: '平台配置',
addPlatform: '添加平台',
noPlatforms: '点击"添加平台"开始配置渠道',
mappingCount: '条映射',
pricingEntry: '定价配置',
noModels: '未添加模型'
} }
}, },
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment