Commit 58b26cb4 authored by erio's avatar erio
Browse files

refactor: merge RecordUsage and RecordUsageWithLongContext into shared core

- Extract recordUsageCore with recordUsageOpts for parameterized differences
- RecordUsage (276 lines) → thin wrapper (~40 lines)
- RecordUsageWithLongContext (251 lines) → thin wrapper (~20 lines)
- Split billing logic into calculateSoraMediaCost, calculateImageCost,
  calculateTokenCost sub-functions
- Extract buildRecordUsageLog for usage log construction
- Net reduction: -79 lines, eliminated ~170 lines of duplication
parent b453c327
......@@ -7706,8 +7706,109 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
}
}
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
LongContextThreshold int
LongContextMultiplier float64
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
return s.recordUsageCore(ctx, &recordUsageCoreInput{
Result: input.Result,
APIKey: input.APIKey,
User: input.User,
Account: input.Account,
Subscription: input.Subscription,
InboundEndpoint: input.InboundEndpoint,
UpstreamEndpoint: input.UpstreamEndpoint,
UserAgent: input.UserAgent,
IPAddress: input.IPAddress,
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true,
})
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type RecordUsageLongContextInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
InboundEndpoint string // 入站端点(客户端请求路径)
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
return s.recordUsageCore(ctx, &recordUsageCoreInput{
Result: input.Result,
APIKey: input.APIKey,
User: input.User,
Account: input.Account,
Subscription: input.Subscription,
InboundEndpoint: input.InboundEndpoint,
UpstreamEndpoint: input.UpstreamEndpoint,
UserAgent: input.UserAgent,
IPAddress: input.IPAddress,
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
LongContextThreshold: input.LongContextThreshold,
LongContextMultiplier: input.LongContextMultiplier,
})
}
// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。
type recordUsageCoreInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
InboundEndpoint string
UpstreamEndpoint string
UserAgent string
IPAddress string
RequestPayloadHash string
ForceCacheBilling bool
APIKeyService APIKeyQuotaUpdater
ChannelUsageFields
}
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
apiKey := input.APIKey
user := input.User
......@@ -7723,9 +7824,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
// Claude Max cache billing policy(仅 Claude 路径启用)
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
simulatedClaudeMax := false
if opts.EnableClaudePath {
var apiKeyGroup *Group
if apiKey != nil {
apiKeyGroup = apiKey.Group
}
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
simulatedClaudeMax = claudeMaxOutcome.Simulated ||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
......@@ -7740,7 +7853,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
......@@ -7756,100 +7868,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
requestedModel = input.OriginalModel
}
// 根据请求类型选择计费方式
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == MediaTypePrompt {
cost = &CostBreakdown{}
} else if result.ImageCount > 0 {
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
hasChannelPricing := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
if hasChannelPricing {
// 渠道定价优先 → 由 CalculateCostUnified 按 resolved.Mode 分发计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
gid := apiKey.Group.ID
var err error
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
} else {
// 无渠道定价 → 走按次计费(默认,兼容旧版本)
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
}
} else {
// Token 计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var err error
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
groupID := &gid
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: groupID,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
}
// 计算费用
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
......@@ -7859,90 +7879,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
CreatedAt: time.Now(),
}
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost
usageLog.ImageOutputCost = cost.ImageOutputCost
usageLog.CacheCreationCost = cost.CacheCreationCost
usageLog.CacheReadCost = cost.CacheReadCost
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
// 设置计费模式
if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
......@@ -7951,6 +7889,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
requestID := usageLog.RequestID
accountRateMultiplier := account.BillingRateMultiplier()
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
......@@ -7974,95 +7914,134 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
type RecordUsageLongContextInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
InboundEndpoint string // 入站端点(客户端请求路径)
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
result := input.Result
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
result.Usage.InputTokens, account.ID)
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
result.Usage.InputTokens = 0
// calculateRecordUsageCost 根据请求类型和选项计算费用。
func (s *GatewayService) calculateRecordUsageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
}
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := 1.0
if s.cfg != nil {
multiplier = s.cfg.Default.RateMultiplier
// Token 计费
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if apiKey.GroupID != nil && apiKey.Group != nil {
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
var cost *CostBreakdown
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
billingModel = input.ChannelMappedModel
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
func (s *GatewayService) calculateImageCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
hasChannelPricing := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
if hasChannelPricing {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
gid := apiKey.Group.ID
cost, err := s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
}
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
func (s *GatewayService) calculateTokenCost(
ctx context.Context,
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
hasChannelPricing := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
if hasChannelPricing {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
gid := apiKey.Group.ID
var err error
var cost *CostBreakdown
var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
......@@ -8072,78 +8051,52 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
RateMultiplier: multiplier,
Resolver: s.resolver,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
} else {
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
}
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
}
var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == PricingSourceChannel {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
}
if !useUnified {
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
useUnified = true
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
if !useUnified {
if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(
billingModel, tokens, multiplier,
opts.LongContextThreshold, opts.LongContextMultiplier,
)
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
}
return cost
}
// 创建使用日志
// buildRecordUsageLog 构建使用日志并设置计费模式。
func (s *GatewayService) buildRecordUsageLog(
ctx context.Context,
input *recordUsageCoreInput,
result *ForwardResult,
apiKey *APIKey,
user *User,
account *Account,
subscription *UserSubscription,
requestedModel string,
multiplier float64,
billingType int8,
cacheTTLOverridden bool,
cost *CostBreakdown,
opts *recordUsageOpts,
) *UsageLog {
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
......@@ -8172,6 +8125,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
......@@ -8187,16 +8141,20 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
usageLog.ActualCost = cost.ActualCost
}
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if !isSoraMedia {
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
}
// 添加 UserAgent
......@@ -8217,34 +8175,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
usageLog.SubscriptionID = &subscription.ID
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil
return usageLog
}
// ResolveChannelMapping 委托渠道服务解析模型映射
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment