Commit 0d241d52 authored by erio's avatar erio
Browse files

refactor: replace magic strings with named constants

- PricingSourceChannel/LiteLLM/Fallback for resolver source
- MediaTypeImage/Video/Prompt for result.MediaType
- Reuse BillingModeToken/BillingModeImage for billing mode
- Reuse BillingModelSourceChannelMapped/PlatformAnthropic in handler
parent 212eaa3a
...@@ -130,7 +130,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -130,7 +130,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
resp.BillingModelSource = ch.BillingModelSource resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" { if resp.BillingModelSource == "" {
resp.BillingModelSource = "channel_mapped" resp.BillingModelSource = service.BillingModelSourceChannelMapped
} }
if resp.GroupIDs == nil { if resp.GroupIDs == nil {
resp.GroupIDs = []int64{} resp.GroupIDs = []int64{}
...@@ -147,11 +147,11 @@ func channelToResponse(ch *service.Channel) *channelResponse { ...@@ -147,11 +147,11 @@ func channelToResponse(ch *service.Channel) *channelResponse {
} }
billingMode := string(p.BillingMode) billingMode := string(p.BillingMode)
if billingMode == "" { if billingMode == "" {
billingMode = "token" billingMode = string(service.BillingModeToken)
} }
platform := p.Platform platform := p.Platform
if platform == "" { if platform == "" {
platform = "anthropic" platform = service.PlatformAnthropic
} }
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals { for _, iv := range p.Intervals {
...@@ -194,7 +194,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -194,7 +194,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
} }
platform := r.Platform platform := r.Platform
if platform == "" { if platform == "" {
platform = "anthropic" platform = service.PlatformAnthropic
} }
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {
......
...@@ -60,6 +60,19 @@ const ( ...@@ -60,6 +60,19 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
const (
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
...@@ -7744,7 +7757,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7744,7 +7757,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
// 根据请求类型选择计费方式 // 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" { if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
var soraConfig *SoraPriceConfig var soraConfig *SoraPriceConfig
if apiKey.Group != nil { if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{ soraConfig = &SoraPriceConfig{
...@@ -7754,12 +7767,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7754,12 +7767,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
} }
} }
if result.MediaType == "image" { if result.MediaType == MediaTypeImage {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else { } else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
} }
} else if result.MediaType == "prompt" { } else if result.MediaType == MediaTypePrompt {
cost = &CostBreakdown{} cost = &CostBreakdown{}
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本) // 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
...@@ -7767,7 +7780,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7767,7 +7780,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
hasChannelPricing = true hasChannelPricing = true
} }
} }
...@@ -7900,15 +7913,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7900,15 +7913,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
// 设置计费模式 // 设置计费模式
if result.MediaType != "image" && result.MediaType != "video" && result.MediaType != "prompt" { if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
if cost != nil && cost.BillingMode != "" { if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
billingMode := "image" billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
} }
...@@ -8038,7 +8051,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8038,7 +8051,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
hasChannelPricing = true hasChannelPricing = true
} }
} }
...@@ -8094,7 +8107,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8094,7 +8107,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Model: billingModel, Model: billingModel,
GroupID: &gid, GroupID: &gid,
}) })
if resolved.Source == "channel" { if resolved.Source == PricingSourceChannel {
// 有渠道定价,渠道区间已包含上下文分层 // 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{ cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
...@@ -8179,10 +8192,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8179,10 +8192,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 { } else if result.ImageCount > 0 {
billingMode := "image" billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
......
...@@ -5,6 +5,13 @@ import ( ...@@ -5,6 +5,13 @@ import (
"log/slog" "log/slog"
) )
// PricingSource 定价来源标识
const (
PricingSourceChannel = "channel"
PricingSourceLiteLLM = "litellm"
PricingSourceFallback = "fallback"
)
// ResolvedPricing 统一定价解析结果 // ResolvedPricing 统一定价解析结果
type ResolvedPricing struct { type ResolvedPricing struct {
// Mode 计费模式 // Mode 计费模式
...@@ -78,9 +85,9 @@ func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, ...@@ -78,9 +85,9 @@ func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing,
if err != nil { if err != nil {
slog.Debug("failed to get model pricing from LiteLLM, using fallback", slog.Debug("failed to get model pricing from LiteLLM, using fallback",
"model", model, "error", err) "model", model, "error", err)
return nil, "fallback" return nil, PricingSourceFallback
} }
return pricing, "litellm" return pricing, PricingSourceLiteLLM
} }
// applyChannelOverrides 应用渠道定价覆盖 // applyChannelOverrides 应用渠道定价覆盖
...@@ -90,7 +97,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI ...@@ -90,7 +97,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI
return return
} }
resolved.Source = "channel" resolved.Source = PricingSourceChannel
resolved.Mode = chPricing.BillingMode resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" { if resolved.Mode == "" {
resolved.Mode = BillingModeToken resolved.Mode = BillingModeToken
......
...@@ -4290,7 +4290,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4290,7 +4290,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
billingMode := cost.BillingMode billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
// 添加 UserAgent // 添加 UserAgent
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment