Commit 611fd884 authored by ius's avatar ius
Browse files

feat: decouple billing correctness from usage log batching

parent c9debc50
......@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
......
......@@ -136,12 +136,14 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
},
}
svc := &GatewayService{
cfg: &config.Config{
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
......@@ -221,12 +223,14 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
},
}
svc := &GatewayService{
cfg: &config.Config{
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
......@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require.Equal(t, 5, result.usage.OutputTokens)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
rateLimitService: &RateLimitService{},
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
"",
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
"",
}, "\n"))),
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_ = pr.Close()
<-done
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 9, result.usage.InputTokens)
......@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
......@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 8, result.usage.InputTokens)
......
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return NewGatewayService(
nil,
nil,
usageRepo,
nil,
userRepo,
subRepo,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
nil,
&DeferredService{},
nil,
nil,
nil,
nil,
nil,
)
}
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.usageBillingRepo = billingRepo
return svc
}
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 501,
Quota: 100,
},
User: &User{ID: 601},
Account: &Account{ID: 701},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_hash",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_fallback",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_not_persisted",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 503,
Quota: 100,
},
User: &User{ID: 603},
Account: &Account{ID: 703},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
Result: &ForwardResult{
RequestID: "gateway_long_context_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 12,
OutputTokens: 8,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 502,
Quota: 100,
},
User: &User{ID: 602},
Account: &Account{ID: 702},
LongContextThreshold: 200000,
LongContextMultiplier: 2,
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 504},
User: &User{ID: 604},
Account: &Account{ID: 704},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_billing_fail",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 505},
User: &User{ID: 605},
Account: &Account{ID: 705},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
This diff is collapsed.
......@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close()
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
......
......@@ -7,6 +7,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
......@@ -17,25 +18,52 @@ type openAIRecordUsageLogRepoStub struct {
err error
calls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.inserted, s.err
}
type openAIRecordUsageBillingRepoStub struct {
UsageBillingRepository
result *UsageBillingApplyResult
err error
calls int
lastCmd *UsageBillingCommand
lastCtxErr error
}
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
s.calls++
s.lastCmd = cmd
s.lastCtxErr = ctx.Err()
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &UsageBillingApplyResult{Applied: true}, nil
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
lastCtxErr error
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
s.lastCtxErr = ctx.Err()
return s.deductErr
}
......@@ -44,10 +72,12 @@ type openAIRecordUsageSubRepoStub struct {
incrementCalls int
incrementErr error
lastCtxErr error
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
s.lastCtxErr = ctx.Err()
return s.incrementErr
}
......@@ -56,17 +86,21 @@ type openAIRecordUsageAPIKeyQuotaStub struct {
rateLimitCalls int
err error
lastAmount float64
lastQuotaCtxErr error
lastRateLimitCtxErr error
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
s.lastQuotaCtxErr = ctx.Err()
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
s.lastRateLimitCtxErr = ctx.Err()
return s.err
}
......@@ -93,23 +127,38 @@ func i64p(v int64) *int64 {
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
svc := NewOpenAIGatewayService(
nil,
usageRepo,
nil,
userRepo,
subRepo,
rateRepo,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
&DeferredService{},
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
)
return svc
}
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
svc.usageBillingRepo = billingRepo
return svc
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
......@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
......@@ -272,9 +322,252 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate_billing_key",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10045,
Quota: 100,
},
User: &User{ID: 20045},
Account: &Account{ID: 30045},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 0, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_usage_log_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10041},
User: &User{ID: 20041},
Account: &Account{ID: 30041},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_not_persisted",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10043,
Quota: 100,
},
User: &User{ID: 20043},
Account: &Account{ID: 30043},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_ctx",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10042,
Quota: 100,
},
User: &User{ID: 20042},
Account: &Account{ID: 30042},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_repo_ctx",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10046},
User: &User{ID: 20046},
Account: &Account{ID: 30046},
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.NoError(t, billingRepo.lastCtxErr)
require.Equal(t, 1, usageRepo.calls)
require.NoError(t, usageRepo.lastCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "openai_payload_hash",
Usage: OpenAIUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "gpt-5",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10047},
User: &User{ID: 20047},
Account: &Account{ID: 30047},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_billing_fail",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10048},
User: &User{ID: 20048},
Account: &Account{ID: 30048},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
......
......@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct {
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageBillingRepo UsageBillingRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
......@@ -295,6 +296,7 @@ type OpenAIGatewayService struct {
func NewOpenAIGatewayService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageBillingRepo UsageBillingRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
......@@ -312,6 +314,7 @@ func NewOpenAIGatewayService(
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
......@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
......@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err
}
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
......@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
var firstTokenMs *int
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
scanner := bufio.NewScanner(resp.Body)
......@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
......@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
}
if err := scanner.Err(); err != nil {
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
if sawTerminalEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
err,
ctx.Err(),
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
......@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if !clientDisconnected && !sawDone && ctx.Err() == nil {
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
......@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
......@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
}
}
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
if sawTerminalEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
......@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
dataBytes := []byte(data)
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
......@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
return resultWithUsage(), nil
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
......@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
if len(data) < 72 {
return
}
if gjson.GetBytes(data, "type").String() != "response.completed" {
eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" {
return
}
......@@ -3677,6 +3693,7 @@ type OpenAIRecordUsageInput struct {
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
}
......@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
RequestID: requestID,
Model: billingModel,
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
......@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
return nil
}
......
......@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
......@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
t.Fatalf("expected incomplete stream error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
......@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
}
}
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 2, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
......@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
......@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require.Equal(t, 3, usage.InputTokens)
require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 2, usage.CacheReadInputTokens)
// done 事件同样可能携带最终 usage
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
require.Equal(t, 13, usage.InputTokens)
require.Equal(t, 15, usage.OutputTokens)
require.Equal(t, 4, usage.CacheReadInputTokens)
}
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
......
......@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
......
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
)
var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
// UsageBillingCommand describes one billable request that must be applied at most once.
type UsageBillingCommand struct {
RequestID string
APIKeyID int64
RequestFingerprint string
RequestPayloadHash string
UserID int64
AccountID int64
SubscriptionID *int64
AccountType string
Model string
ServiceTier string
ReasoningEffort string
BillingType int8
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
ImageCount int
MediaType string
BalanceCost float64
SubscriptionCost float64
APIKeyQuotaCost float64
APIKeyRateLimitCost float64
AccountQuotaCost float64
}
func (c *UsageBillingCommand) Normalize() {
if c == nil {
return
}
c.RequestID = strings.TrimSpace(c.RequestID)
if strings.TrimSpace(c.RequestFingerprint) == "" {
c.RequestFingerprint = buildUsageBillingFingerprint(c)
}
}
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
if c == nil {
return ""
}
raw := fmt.Sprintf(
"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
c.UserID,
c.AccountID,
c.APIKeyID,
strings.TrimSpace(c.AccountType),
strings.TrimSpace(c.Model),
strings.TrimSpace(c.ServiceTier),
strings.TrimSpace(c.ReasoningEffort),
c.BillingType,
c.InputTokens,
c.OutputTokens,
c.CacheCreationTokens,
c.CacheReadTokens,
c.ImageCount,
strings.TrimSpace(c.MediaType),
valueOrZero(c.SubscriptionID),
c.BalanceCost,
c.SubscriptionCost,
c.APIKeyQuotaCost,
c.APIKeyRateLimitCost,
c.AccountQuotaCost,
)
if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
raw += "|" + payloadHash
}
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
func HashUsageRequestPayload(payload []byte) string {
if len(payload) == 0 {
return ""
}
sum := sha256.Sum256(payload)
return hex.EncodeToString(sum[:])
}
func valueOrZero(v *int64) int64 {
if v == nil {
return 0
}
return *v
}
type UsageBillingApplyResult struct {
Applied bool
APIKeyQuotaExhausted bool
}
type UsageBillingRepository interface {
Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}
......@@ -57,6 +57,7 @@ type cleanupRepoStub struct {
type dashboardRepoStub struct {
recomputeErr error
recomputeCalls int
}
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
......@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
}
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.recomputeErr
}
......@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
return nil
}
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
......@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
......@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
}
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
dashboardRepo := &dashboardRepoStub{}
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{deleted: 0},
},
}
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
})
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
......@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
}
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
......
package service
import "errors"
type usageLogCreateDisposition int
const (
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
usageLogCreateDispositionNotPersisted
)
type UsageLogCreateError struct {
err error
disposition usageLogCreateDisposition
}
func (e *UsageLogCreateError) Error() string {
if e == nil || e.err == nil {
return "usage log create error"
}
return e.err.Error()
}
func (e *UsageLogCreateError) Unwrap() error {
if e == nil {
return nil
}
return e.err
}
func MarkUsageLogCreateNotPersisted(err error) error {
if err == nil {
return nil
}
return &UsageLogCreateError{
err: err,
disposition: usageLogCreateDispositionNotPersisted,
}
}
func IsUsageLogCreateNotPersisted(err error) bool {
if err == nil {
return false
}
var target *UsageLogCreateError
if !errors.As(err, &target) {
return false
}
return target.disposition == usageLogCreateDispositionNotPersisted
}
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
if inserted {
return true
}
if err == nil {
return false
}
return !IsUsageLogCreateNotPersisted(err)
}
-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来
-- 幂等执行:可重复运行
CREATE TABLE IF NOT EXISTS usage_billing_dedup (
id BIGSERIAL PRIMARY KEY,
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
ON usage_billing_dedup (request_id, api_key_id);
-- usage_billing_dedup 是按时间追加写入的幂等窄表。
-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。
-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
ON usage_billing_dedup
USING BRIN (created_at);
-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。
CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (request_id, api_key_id)
);
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment