Commit f585a15e authored by shaw's avatar shaw
Browse files

fix(billing): prevent channel_mapped override from reverting BillingModel when channel did not map

When a channel has no model mapping for the requested model, ChannelMappedModel
equals OriginalModel (the user's arbitrary input). Combined with the default
BillingModelSource="channel_mapped", this incorrectly overrides the BillingModel
set by the OpenAI format conversion layer (e.g., gpt-5.4 from DefaultMappedModel)
back to the unmapped original model (e.g., glm) which has no pricing — resulting
in zero-cost billing.

Add guard condition so the channel_mapped override only fires when the channel
actually changed the model (ChannelMappedModel != OriginalModel).
parent dbb248df
......@@ -933,6 +933,89 @@ func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
}
func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingModelWhenUnmapped(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
// When channel did NOT map the model (ChannelMappedModel == OriginalModel),
// billing should use result.BillingModel (the actual model used after group
// DefaultMappedModel resolution), not the unmapped original model.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_channel_unmapped_billing",
Model: "glm",
BillingModel: "gpt-5.1",
UpstreamModel: "gpt-5.1",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
ChannelUsageFields: ChannelUsageFields{
ChannelID: 1,
OriginalModel: "glm",
ChannelMappedModel: "glm", // channel did NOT map
BillingModelSource: BillingModelSourceChannelMapped,
},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenMapped(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
// When channel DID map the model (ChannelMappedModel != OriginalModel),
// billing should use the channel-mapped model, honoring admin intent.
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
InputTokens: 20,
OutputTokens: 10,
}, 1.1)
require.NoError(t, err)
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_channel_mapped_billing",
Model: "glm",
BillingModel: "gpt-5.1-codex",
UpstreamModel: "gpt-5.1-codex",
Usage: usage,
Duration: time.Second,
},
APIKey: &APIKey{ID: 10},
User: &User{ID: 20},
Account: &Account{ID: 30},
ChannelUsageFields: ChannelUsageFields{
ChannelID: 1,
OriginalModel: "glm",
ChannelMappedModel: "gpt-5.1", // channel mapped glm → gpt-5.1
BillingModelSource: BillingModelSourceChannelMapped,
},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
......
......@@ -4277,7 +4277,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" && input.ChannelMappedModel != input.OriginalModel {
billingModel = input.ChannelMappedModel
}
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment