Commit 3cd398b0 authored by erio
Browse files

refactor: extract computeTokenBreakdown to deduplicate billing logic

- calculateTokenCost reduced from 80 to 15 lines
- calculateCostInternal reduced from 91 to 15 lines
- Shared logic in computeTokenBreakdown + computeCacheCreationCost
- Unified rateMultiplier <= 0 protection in both paths
parent d3127b8e
...@@ -469,76 +469,97 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos ...@@ -469,76 +469,97 @@ func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input Cos
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing) pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
breakdown := &CostBreakdown{} // 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层)
inputPricePerToken := pricing.InputPricePerToken applyLongCtx := len(resolved.Intervals) == 0
outputPricePerToken := pricing.OutputPricePerToken
cacheReadPricePerToken := pricing.CacheReadPricePerToken return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil
}
// computeTokenBreakdown is the core token-billing logic shared by
// calculateTokenCost and calculateCostInternal.
//
// Parameters:
//   - pricing: per-token prices and multipliers for the model. It is
//     dereferenced without a nil check.
//   - tokens: the usage counters to bill.
//   - rateMultiplier: factor applied to TotalCost to obtain ActualCost; a value
//     <= 0 is replaced by 1.0.
//   - serviceTier: passed to usePriorityServiceTierPricing and
//     serviceTierCostMultiplier to select priority prices or a tier multiplier.
//   - applyLongCtx: controls whether long-context pricing is checked (interval
//     pricing already includes context tiering, so it needs no extra
//     application).
//
// Steps, in order:
//  1. Start from the standard input / output / cache-read prices.
//  2. If usePriorityServiceTierPricing reports true, each priority price that
//     is > 0 replaces the matching standard price; otherwise a tier multiplier
//     is taken from serviceTierCostMultiplier.
//  3. If applyLongCtx is set and shouldApplySessionLongContextPricing reports
//     true, the input and output prices are multiplied by the long-context
//     multipliers.
//  4. Text output tokens are OutputTokens minus ImageOutputTokens, clamped at
//     zero. Image output tokens use ImageOutputPricePerToken, or the
//     (already adjusted) output price when that field is zero.
//  5. Cache-creation cost comes from computeCacheCreationCost; cache-read cost
//     is CacheReadTokens times the cache-read price.
//  6. A tier multiplier other than 1.0 scales all five cost components.
//  7. TotalCost is the sum of the components; ActualCost is TotalCost times
//     rateMultiplier.
//
// The returned *CostBreakdown is freshly allocated on every call.
func (s *BillingService) computeTokenBreakdown(
pricing *ModelPricing, tokens UsageTokens,
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
// Guard against a zero or negative multiplier: fall back to 1.0 so that
// ActualCost equals TotalCost instead of becoming zero or negative.
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
// Standard per-token prices; they may be overridden by priority-tier prices
// and scaled by the long-context multipliers below.
inputPrice := pricing.InputPricePerToken
outputPrice := pricing.OutputPricePerToken
cacheReadPrice := pricing.CacheReadPricePerToken
tierMultiplier := 1.0 tierMultiplier := 1.0
if usePriorityServiceTierPricing(input.ServiceTier, pricing) { if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 { if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority inputPrice = pricing.InputPricePerTokenPriority
} }
if pricing.OutputPricePerTokenPriority > 0 { if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority outputPrice = pricing.OutputPricePerTokenPriority
} }
if pricing.CacheReadPricePerTokenPriority > 0 { if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority cacheReadPrice = pricing.CacheReadPricePerTokenPriority
} }
} else { } else {
tierMultiplier = serviceTierCostMultiplier(input.ServiceTier) tierMultiplier = serviceTierCostMultiplier(serviceTier)
} }
// Long-context pricing (applied only when there is no interval pricing; interval pricing already includes context tiering) if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) {
if len(resolved.Intervals) == 0 && s.shouldApplySessionLongContextPricing(input.Tokens, pricing) { inputPrice *= pricing.LongContextInputMultiplier
inputPricePerToken *= pricing.LongContextInputMultiplier outputPrice *= pricing.LongContextOutputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
} }
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken bd := &CostBreakdown{}
bd.InputCost = float64(tokens.InputTokens) * inputPrice
// Separate image output tokens from text output tokens // split image output tokens from text output tokens
textOutputTokens := input.Tokens.OutputTokens - input.Tokens.ImageOutputTokens textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 { if textOutputTokens < 0 {
textOutputTokens = 0 textOutputTokens = 0
} }
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken bd.OutputCost = float64(textOutputTokens) * outputPrice
// Image output tokens cost (separate rate from text output) // image output token cost (separate rate)
if input.Tokens.ImageOutputTokens > 0 { if tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken imgPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 { if imgPrice == 0 {
imageOutputPrice = outputPricePerToken // fallback to regular output price imgPrice = outputPrice // fall back to the regular output price
} }
breakdown.ImageOutputCost = float64(input.Tokens.ImageOutputTokens) * imageOutputPrice bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice
} }
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { // cache creation cost
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 { bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens)
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(input.Tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
} else {
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(input.Tokens.CacheReadTokens) * cacheReadPricePerToken bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice
// The tier multiplier is only set on the non-priority path above; when it
// differs from 1.0 it scales every cost component uniformly.
if tierMultiplier != 1.0 { if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier bd.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier bd.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier bd.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier bd.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier bd.CacheReadCost *= tierMultiplier
} }
// TotalCost is the sum of the five components; ActualCost additionally
// applies the rate multiplier.
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost + bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost bd.CacheCreationCost + bd.CacheReadCost
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier bd.ActualCost = bd.TotalCost * rateMultiplier
return breakdown, nil return bd
}
// computeCacheCreationCost computes the cache-creation cost, supporting either
// the 5m/1h breakdown or standard billing.
//
// Breakdown billing is used when pricing.SupportsCacheBreakdown is set and at
// least one of CacheCreation5mPrice / CacheCreation1hPrice is positive:
//   - if both the 5m and 1h token counts are zero but the aggregate
//     CacheCreationTokens is positive, the whole aggregate is billed at the
//     5m price;
//   - otherwise the result is 5m tokens times the 5m price plus 1h tokens
//     times the 1h price (the aggregate count is not consulted).
//
// In every other case the aggregate CacheCreationTokens is billed at
// CacheCreationPricePerToken.
//
// pricing is dereferenced without a nil check, and the receiver s is not used
// by the body.
func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 {
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// The API did not return the ephemeral breakdown; fall back to billing everything at the 5m price.
return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
}
return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
// Standard billing: a single per-token cache-creation price.
return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
} }
// calculatePerRequestCost 按次/图片计费 // calculatePerRequestCost 按次/图片计费
...@@ -594,84 +615,8 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, ...@@ -594,84 +615,8 @@ func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens,
return nil, err return nil, err
} }
breakdown := &CostBreakdown{} // 旧路径始终检查长上下文定价(无区间定价概念)
inputPricePerToken := pricing.InputPricePerToken return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil
outputPricePerToken := pricing.OutputPricePerToken
cacheReadPricePerToken := pricing.CacheReadPricePerToken
tierMultiplier := 1.0
if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority
}
if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority
}
if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
}
} else {
tierMultiplier = serviceTierCostMultiplier(serviceTier)
}
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
}
// 计算输入token费用(使用per-token价格)
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用(分离图片输出token)
textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens
if textOutputTokens < 0 {
textOutputTokens = 0
}
breakdown.OutputCost = float64(textOutputTokens) * outputPricePerToken
// 图片输出 token 费用
if tokens.ImageOutputTokens > 0 {
imageOutputPrice := pricing.ImageOutputPricePerToken
if imageOutputPrice == 0 {
imageOutputPrice = outputPricePerToken
}
breakdown.ImageOutputCost = float64(tokens.ImageOutputTokens) * imageOutputPrice
}
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
} else {
// 标准缓存创建价格(per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.ImageOutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
}
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.ImageOutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
return breakdown, nil
} }
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment