Commit b4a42a64 authored by erio's avatar erio
Browse files

refactor: extract helpers to reduce duplication and function length in gateway billing

- Extract resolveChannelPricing to DRY the resolver pattern shared by calculateImageCost/calculateTokenCost
- Remove unnecessary IIFE wrapper and pass accountRateMultiplier as parameter
- Extract resolveBillingMode, resolveMediaType, optionalSubscriptionID to simplify buildRecordUsageLog (104→65 lines)
- Extract shouldDeductAPIKeyQuota/shouldUpdateRateLimits/shouldUpdateAccountQuota methods on postUsageBillingParams to unify duplicated billing conditions
parent 58b26cb4
...@@ -7451,6 +7451,18 @@ type postUsageBillingParams struct { ...@@ -7451,6 +7451,18 @@ type postUsageBillingParams struct {
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
} }
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑: // postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费 // - 订阅/余额扣费
// - API Key 配额更新 // - API Key 配额更新
...@@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill ...@@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
} }
// 2. API Key 配额 // 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 3. API Key 限速用量 // 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
...@@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage ...@@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.BalanceCost = p.Cost.ActualCost cmd.BalanceCost = p.Cost.ActualCost
} }
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { if p.shouldDeductAPIKeyQuota() {
cmd.APIKeyQuotaCost = p.Cost.ActualCost cmd.APIKeyQuotaCost = p.Cost.ActualCost
} }
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if p.shouldUpdateRateLimits() {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost cmd.APIKeyRateLimitCost = p.Cost.ActualCost
} }
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { if p.shouldUpdateAccountQuota() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
} }
...@@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ...@@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
} }
// 创建使用日志 // 创建使用日志
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts) requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
...@@ -7890,21 +7903,17 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ...@@ -7890,21 +7903,17 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
} }
requestID := usageLog.RequestID requestID := usageLog.RequestID
accountRateMultiplier := account.BillingRateMultiplier() _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
billingErr := func() error { Cost: cost,
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ User: user,
Cost: cost, APIKey: apiKey,
User: user, Account: account,
APIKey: apiKey, Subscription: subscription,
Account: account, RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
Subscription: subscription, IsSubscriptionBill: isSubscriptionBilling,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), AccountRateMultiplier: accountRateMultiplier,
IsSubscriptionBill: isSubscriptionBilling, APIKeyService: input.APIKeyService,
AccountRateMultiplier: accountRateMultiplier, }, s.billingDeps(), s.usageBillingRepo)
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil { if billingErr != nil {
return billingErr return billingErr
...@@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost( ...@@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost(
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
} }
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
if s.resolver == nil || apiKey.Group == nil {
return nil
}
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
return resolved
}
return nil
}
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 // calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
func (s *GatewayService) calculateImageCost( func (s *GatewayService) calculateImageCost(
ctx context.Context, ctx context.Context,
...@@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost( ...@@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string, billingModel string,
multiplier float64, multiplier float64,
) *CostBreakdown { ) *CostBreakdown {
hasChannelPricing := false if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
if hasChannelPricing {
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
...@@ -8036,34 +8051,26 @@ func (s *GatewayService) calculateTokenCost( ...@@ -8036,34 +8051,26 @@ func (s *GatewayService) calculateTokenCost(
var cost *CostBreakdown var cost *CostBreakdown
var err error var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) // 优先尝试渠道定价 → CalculateCostUnified
useUnified := false if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) cost, err = s.billingService.CalculateCostUnified(CostInput{
if resolved.Source == PricingSourceChannel { Ctx: ctx,
cost, err = s.billingService.CalculateCostUnified(CostInput{ Model: billingModel,
Ctx: ctx, GroupID: &gid,
Model: billingModel, Tokens: tokens,
GroupID: &gid, RequestCount: 1,
Tokens: tokens, RateMultiplier: multiplier,
RequestCount: 1, Resolver: s.resolver,
RateMultiplier: multiplier, })
Resolver: s.resolver, } else if opts.LongContextThreshold > 0 {
}) // 长上下文双倍计费(如 Gemini 200K 阈值)
useUnified = true cost, err = s.billingService.CalculateCostWithLongContext(
} billingModel, tokens, multiplier,
} opts.LongContextThreshold, opts.LongContextMultiplier,
if !useUnified { )
if opts.LongContextThreshold > 0 { } else {
// 长上下文双倍计费(如 Gemini 200K 阈值) cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
cost, err = s.billingService.CalculateCostWithLongContext(
billingModel, tokens, multiplier,
opts.LongContextThreshold, opts.LongContextMultiplier,
)
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
} }
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
...@@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog(
subscription *UserSubscription, subscription *UserSubscription,
requestedModel string, requestedModel string,
multiplier float64, multiplier float64,
accountRateMultiplier float64,
billingType int8, billingType int8,
cacheTTLOverridden bool, cacheTTLOverridden bool,
cost *CostBreakdown, cost *CostBreakdown,
opts *recordUsageOpts, opts *recordUsageOpts,
) *UsageLog { ) *UsageLog {
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID) requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
...@@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier, AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: imageSize, ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: mediaType, MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID), ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
GroupID: apiKey.GroupID,
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
if cost != nil { if cost != nil {
...@@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog(
usageLog.ActualCost = cost.ActualCost usageLog.ActualCost = cost.ActualCost
} }
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过 return usageLog
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath && isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if !isSoraMedia { if isSoraMedia {
if cost != nil && cost.BillingMode != "" { return nil
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
} }
var mode string
// 添加 UserAgent switch {
if input.UserAgent != "" { case cost != nil && cost.BillingMode != "":
usageLog.UserAgent = &input.UserAgent mode = cost.BillingMode
case result.ImageCount > 0:
mode = string(BillingModeImage)
default:
mode = string(BillingModeToken)
} }
return &mode
}
// 添加 IPAddress func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if input.IPAddress != "" { if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
usageLog.IPAddress = &input.IPAddress return &result.MediaType
} }
return nil
}
// 添加分组和订阅关联 func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil { if subscription != nil {
usageLog.SubscriptionID = &subscription.ID return &subscription.ID
} }
return nil
return usageLog
} }
// ResolveChannelMapping 委托渠道服务解析模型映射 // ResolveChannelMapping 委托渠道服务解析模型映射
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment