"git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "0170d19fa7d9fdb5467dfbecb2dcef3372423066"
Commit 4edcfe1f authored by Ethan0x0000's avatar Ethan0x0000
Browse files

fix(usage): preserve requested model in gateway billing paths

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent

)
Co-authored-by: default avatarSisyphus <clio-agent@sisyphuslabs.ai>
parent 9259dcb6
......@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
mappedModel := "claude-sonnet-4-20250514"
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_models_split",
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
Model: "claude-sonnet-4",
UpstreamModel: mappedModel,
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -482,10 +482,12 @@ type ClaudeUsage struct {
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
RequestID string
Usage ClaudeUsage
Model string
// UpstreamModel is the actual upstream model after mapping.
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
UpstreamModel string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
......@@ -7516,6 +7518,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.MediaType == "image" || result.MediaType == "video" {
......@@ -7531,7 +7534,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.MediaType == "image" {
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
} else {
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
} else if result.MediaType == "prompt" {
cost = &CostBreakdown{}
......@@ -7545,7 +7548,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费
tokens := UsageTokens{
......@@ -7557,7 +7560,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
......@@ -7589,6 +7592,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
......@@ -7719,6 +7723,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
var cost *CostBreakdown
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
......@@ -7731,7 +7736,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
......@@ -7743,7 +7748,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
......@@ -7771,6 +7776,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
......
......@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
require.NotNil(t, usageRepo.lastLog.ServiceTier)
......@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_upstream_model_billing_fallback",
Model: "gpt-5.1",
UpstreamModel: "gpt-5.1-codex",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost)
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
billingModel := result.Model
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = result.BillingModel
billingModel = strings.TrimSpace(result.BillingModel)
}
serviceTier := ""
if result.ServiceTier != nil {
......@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
RequestedModel: result.Model,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment