Commit b4a42a64 authored by erio's avatar erio
Browse files

refactor: extract helpers to reduce duplication and function length in gateway billing

- Extract resolveChannelPricing to DRY the resolver pattern shared by calculateImageCost/calculateTokenCost
- Remove unnecessary IIFE wrapper and pass accountRateMultiplier as parameter
- Extract resolveBillingMode, resolveMediaType, optionalSubscriptionID to simplify buildRecordUsageLog (104→65 lines)
- Extract shouldDeductAPIKeyQuota/shouldUpdateRateLimits/shouldUpdateAccountQuota methods on postUsageBillingParams to unify duplicated billing conditions
parent 58b26cb4
......@@ -7451,6 +7451,18 @@ type postUsageBillingParams struct {
APIKeyService APIKeyQuotaUpdater
}
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
}
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
......@@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
// 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
......@@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.BalanceCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if p.shouldDeductAPIKeyQuota() {
cmd.APIKeyQuotaCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if p.shouldUpdateRateLimits() {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
}
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
if p.shouldUpdateAccountQuota() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
}
......@@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
// 创建使用日志
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts)
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
......@@ -7890,9 +7903,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
requestID := usageLog.RequestID
accountRateMultiplier := account.BillingRateMultiplier()
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
......@@ -7903,8 +7914,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
......@@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost(
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
if s.resolver == nil || apiKey.Group == nil {
return nil
}
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
return resolved
}
return nil
}
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
func (s *GatewayService) calculateImageCost(
ctx context.Context,
......@@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string,
multiplier float64,
) *CostBreakdown {
hasChannelPricing := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
hasChannelPricing = true
}
}
if hasChannelPricing {
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
......@@ -8036,12 +8051,9 @@ func (s *GatewayService) calculateTokenCost(
var cost *CostBreakdown
var err error
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
// 优先尝试渠道定价 → CalculateCostUnified
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
if resolved.Source == PricingSourceChannel {
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
......@@ -8051,11 +8063,7 @@ func (s *GatewayService) calculateTokenCost(
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
}
if !useUnified {
if opts.LongContextThreshold > 0 {
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(
billingModel, tokens, multiplier,
......@@ -8064,7 +8072,6 @@ func (s *GatewayService) calculateTokenCost(
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
}
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
return &CostBreakdown{ActualCost: 0}
......@@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog(
subscription *UserSubscription,
requestedModel string,
multiplier float64,
accountRateMultiplier float64,
billingType int8,
cacheTTLOverridden bool,
cost *CostBreakdown,
opts *recordUsageOpts,
) *UsageLog {
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
var mediaType *string
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
mediaType = &result.MediaType
}
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
......@@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
GroupID: apiKey.GroupID,
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(),
}
if cost != nil {
......@@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog(
usageLog.ActualCost = cost.ActualCost
}
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过
return usageLog
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if !isSoraMedia {
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else if result.ImageCount > 0 {
billingMode := string(BillingModeImage)
usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
}
if isSoraMedia {
return nil
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
var mode string
switch {
case cost != nil && cost.BillingMode != "":
mode = cost.BillingMode
case result.ImageCount > 0:
mode = string(BillingModeImage)
default:
mode = string(BillingModeToken)
}
return &mode
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
return &subscription.ID
}
return usageLog
return nil
}
// ResolveChannelMapping 委托渠道服务解析模型映射
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment