package service import ( "fmt" "sort" "strings" "time" ) // BillingMode 计费模式 type BillingMode string const ( BillingModeToken BillingMode = "token" // 按 token 区间计费 BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) ) // IsValid 检查 BillingMode 是否为合法值 func (m BillingMode) IsValid() bool { switch m { case BillingModeToken, BillingModePerRequest, BillingModeImage, "": return true } return false } const ( BillingModelSourceRequested = "requested" BillingModelSourceUpstream = "upstream" BillingModelSourceChannelMapped = "channel_mapped" ) // Channel 渠道实体 type Channel struct { ID int64 Name string Description string Status string BillingModelSource string // "requested", "upstream", or "channel_mapped" RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) CreatedAt time.Time UpdatedAt time.Time // 关联的分组 ID 列表 GroupIDs []int64 // 模型定价列表(每条含 Platform 字段) ModelPricing []ChannelModelPricing // 渠道级模型映射(按平台分组:platform → {src→dst}) ModelMapping map[string]map[string]string } // ChannelModelPricing 渠道模型定价条目 type ChannelModelPricing struct { ID int64 ChannelID int64 Platform string // 所属平台(anthropic/openai/gemini/...) Models []string // 绑定的模型列表 BillingMode BillingMode // 计费模式 InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 OutputPrice *float64 // 每 token 输出价格(USD) CacheWritePrice *float64 // 缓存写入价格 CacheReadPrice *float64 // 缓存读取价格 ImageOutputPrice *float64 // 图片输出价格(向后兼容) PerRequestPrice *float64 // 默认按次计费价格(USD) Intervals []PricingInterval // 区间定价列表 CreatedAt time.Time UpdatedAt time.Time } // PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层) type PricingInterval struct { ID int64 PricingID int64 MinTokens int // 区间下界(含) MaxTokens *int // 区间上界(不含),nil = 无上限 TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等) InputPrice *float64 // token 模式:每 token 输入价 OutputPrice *float64 // token 模式:每 token 输出价 CacheWritePrice *float64 // token 模式:缓存写入价 CacheReadPrice *float64 // token 模式:缓存读取价 PerRequestPrice *float64 // 按次/图片模式:每次请求价格 SortOrder int CreatedAt time.Time UpdatedAt time.Time } // IsActive 判断渠道是否启用 func (c *Channel) IsActive() bool { return c.Status == StatusActive } // GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 // 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。 func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { modelLower := strings.ToLower(model) for i := range c.ModelPricing { for _, m := range c.ModelPricing[i].Models { if strings.ToLower(m) == modelLower { cp := c.ModelPricing[i].Clone() return &cp } } } return nil } // FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。 // 区间为左开右闭 (min, max]:min 不含,max 包含。 // 第一个区间 min=0 时,0 token 不匹配任何区间(回退到默认价格)。 func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval { for i := range intervals { iv := &intervals[i] if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) { return iv } } return nil } // GetIntervalForContext 根据总 context token 数查找匹配的区间。 func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval { return FindMatchingInterval(p.Intervals, totalTokens) } // GetTierByLabel 根据标签查找层级(用于 per_request / image 模式) func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval { labelLower := strings.ToLower(label) for i := range p.Intervals { if strings.ToLower(p.Intervals[i].TierLabel) == labelLower { return &p.Intervals[i] } } return nil } // Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全) func (p ChannelModelPricing) Clone() ChannelModelPricing { cp := p if p.Models != nil { cp.Models = make([]string, len(p.Models)) copy(cp.Models, p.Models) } if p.Intervals != nil { cp.Intervals = make([]PricingInterval, len(p.Intervals)) copy(cp.Intervals, p.Intervals) } return cp } // Clone 返回 Channel 的深拷贝 func (c *Channel) Clone() *Channel { if c == nil { return nil } cp := *c if c.GroupIDs != nil { cp.GroupIDs = make([]int64, len(c.GroupIDs)) copy(cp.GroupIDs, c.GroupIDs) } if c.ModelPricing != nil { cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing)) for i := range c.ModelPricing { cp.ModelPricing[i] = c.ModelPricing[i].Clone() } } if c.ModelMapping != nil { cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping)) for platform, mapping := range c.ModelMapping { inner := make(map[string]string, len(mapping)) for k, v := range mapping { inner[k] = v } cp.ModelMapping[platform] = inner } } return &cp } // ValidateIntervals 校验区间列表的合法性。 // 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; // 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); // 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。 func ValidateIntervals(intervals []PricingInterval) error { if len(intervals) == 0 { return nil } sorted := make([]PricingInterval, len(intervals)) copy(sorted, intervals) sort.Slice(sorted, func(i, j int) bool { return sorted[i].MinTokens < sorted[j].MinTokens }) for i := range sorted { if err := validateSingleInterval(&sorted[i], i); err != nil { return err } } return validateIntervalOverlap(sorted) } // validateSingleInterval 校验单个区间的字段合法性 func validateSingleInterval(iv *PricingInterval, idx int) error { if iv.MinTokens < 0 { return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens) } if iv.MaxTokens != nil { if *iv.MaxTokens <= 0 { return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens) } if *iv.MaxTokens <= iv.MinTokens { return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)", idx+1, *iv.MaxTokens, iv.MinTokens) } } return validateIntervalPrices(iv, idx) } // validateIntervalPrices 校验区间内所有价格字段 >= 0 func validateIntervalPrices(iv *PricingInterval, idx int) error { prices := []struct { name string val *float64 }{ {"input_price", iv.InputPrice}, {"output_price", iv.OutputPrice}, {"cache_write_price", iv.CacheWritePrice}, {"cache_read_price", iv.CacheReadPrice}, {"per_request_price", iv.PerRequestPrice}, } for _, p := range prices { if p.val != nil && *p.val < 0 { return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name) } } return nil } // validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后 func validateIntervalOverlap(sorted []PricingInterval) error { for i, iv := range sorted { // 无界区间必须是最后一个 if iv.MaxTokens == nil && i < len(sorted)-1 { return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one", i+1) } if i == 0 { continue } prev := sorted[i-1] // 检查重叠:前一个区间的上界 > 当前区间的下界则重叠 // (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max] if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens { return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d", i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens) } } return nil } func formatMaxTokensLabel(max *int) string { if max == nil { return "∞" } return fmt.Sprintf("%d", *max) } // ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) type ChannelUsageFields struct { ChannelID int64 // 渠道 ID(0 = 无渠道) OriginalModel string // 用户原始请求模型(渠道映射前) ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel) BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" ModelMappingChain string // 映射链描述,如 "a→b→c" }