Commit 0d241d52 authored by erio's avatar erio
Browse files

refactor: replace magic strings with named constants

- PricingSourceChannel/LiteLLM/Fallback for resolver source
- MediaTypeImage/Video/Prompt for result.MediaType
- Reuse BillingModeToken/BillingModeImage for billing mode
- Reuse BillingModelSourceChannelMapped/PlatformAnthropic in handler
parent 212eaa3a
......@@ -130,7 +130,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
resp.BillingModelSource = "channel_mapped"
resp.BillingModelSource = service.BillingModelSourceChannelMapped
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
......@@ -147,11 +147,11 @@ func channelToResponse(ch *service.Channel) *channelResponse {
}
billingMode := string(p.BillingMode)
if billingMode == "" {
billingMode = "token"
billingMode = string(service.BillingModeToken)
}
platform := p.Platform
if platform == "" {
platform = "anthropic"
platform = service.PlatformAnthropic
}
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
for _, iv := range p.Intervals {
......@@ -194,7 +194,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
}
platform := r.Platform
if platform == "" {
platform = "anthropic"
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals {
......
......@@ -60,6 +60,19 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
const (
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
......@@ -7744,7 +7757,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
......@@ -7754,12 +7767,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == "image" {
if result.MediaType == MediaTypeImage {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == "prompt" {
} else if result.MediaType == MediaTypePrompt {
cost = &CostBreakdown{}
} else if result.ImageCount > 0 {
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
......@@ -7767,7 +7780,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" {
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
......@@ -7900,15 +7913,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 设置计费模式
if result.MediaType != "image" && result.MediaType != "video" && result.MediaType != "prompt" {
if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := "image"
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
}
......@@ -8038,7 +8051,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == "channel" {
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
......@@ -8094,7 +8107,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == "channel" {
if resolved.Source == PricingSourceChannel {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
......@@ -8179,10 +8192,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := "image"
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
......
......@@ -5,6 +5,13 @@ import (
"log/slog"
)
// PricingSource 定价来源标识
const (
PricingSourceChannel = "channel"
PricingSourceLiteLLM = "litellm"
PricingSourceFallback = "fallback"
)
// ResolvedPricing 统一定价解析结果
type ResolvedPricing struct {
// Mode 计费模式
......@@ -78,9 +85,9 @@ func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing,
if err != nil {
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
"model", model, "error", err)
return nil, "fallback"
return nil, PricingSourceFallback
}
return pricing, "litellm"
return pricing, PricingSourceLiteLLM
}
// applyChannelOverrides 应用渠道定价覆盖
......@@ -90,7 +97,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI
return
}
resolved.Source = "channel"
resolved.Source = PricingSourceChannel
resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" {
resolved.Mode = BillingModeToken
......
......@@ -4290,7 +4290,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
// 添加 UserAgent
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment